Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#524)

Parent commit: 711a28d9bf
This commit: 71f60f6df0
.gitignore (vendored): 4 changes
@@ -18,6 +18,10 @@ venv/
.hypothesis/
.idea/
cpp/cmake-build-*
cpp/.ccache/
tensorrt_llm/libs
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi

# Testing
.coverage.*
.gitmodules (vendored): 1 change
@@ -1,7 +1,6 @@
[submodule "3rdparty/cutlass"]
    path = 3rdparty/cutlass
    url = https://github.com/NVIDIA/cutlass.git
    branch = v2.10.0
[submodule "3rdparty/json"]
    path = 3rdparty/json
    url = https://github.com/nlohmann/json.git
.pre-commit-config.yaml

@@ -15,7 +15,7 @@ repos:
  rev: v4.1.0
  hooks:
  - id: check-added-large-files
    exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin'
    exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/'
  - id: check-merge-conflict
  - id: check-symlinks
  - id: detect-private-key
@@ -33,9 +33,7 @@ repos:
  - id: clang-format
    types_or: [c++, c, cuda]
    exclude: |
      (?x)^(
          cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/.*
      )$
    exclude: (?x)^(.*cubin.cpp$ | .*fmha_cubin.h)$
- repo: https://github.com/cheshirekow/cmake-format-precommit
  rev: v0.6.10
  hooks:
README.md: 124 changes
@@ -36,7 +36,6 @@ H200 FP8 achieves 11,819 tok/s on Llama2-13B on a single GPU, and is up to 1.9x
[2023/10/4 - Perplexity](https://blog.perplexity.ai/blog/introducing-pplx-api) ;
[2023/9/27 - CloudFlare](https://www.cloudflare.com/press-releases/2023/cloudflare-powers-hyper-local-ai-inference-with-nvidia/);

## Table of Contents

- [TensorRT-LLM Overview](#tensorrt-llm-overview)

@@ -186,7 +185,8 @@ TensorRT-LLM is rigorously tested on the following GPUs:

* [H100](https://www.nvidia.com/en-us/data-center/h100/)
* [L40S](https://www.nvidia.com/en-us/data-center/l40s/)
* [A100](https://www.nvidia.com/en-us/data-center/a100/)/[A30](https://www.nvidia.com/en-us/data-center/products/a30-gpu/)
* [A100](https://www.nvidia.com/en-us/data-center/a100/)
* [A30](https://www.nvidia.com/en-us/data-center/products/a30-gpu/)
* [V100](https://www.nvidia.com/en-us/data-center/v100/) (experimental)

If a GPU is not listed above, it is important to note that TensorRT-LLM is
@@ -254,6 +254,7 @@ The list of supported models is:
* [LLaMA-v2](examples/llama)
* [Mistral](examples/llama)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
* [Qwen](examples/qwen)
* [Replit Code](examples/mpt)
@@ -261,7 +262,10 @@ The list of supported models is:
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)

Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models easier.

## Performance

@@ -325,7 +329,11 @@ enable plugins, for example: `--use_gpt_attention_plugin`.

* MPI + Slurm

TensorRT-LLM is a [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might encounter interferences:
```
--------------------------------------------------------------------------
PMI2_Init failed to initialize. Return code: 14
@@ -347,19 +355,123 @@ SLURM, depending upon the SLURM version you are using:
Please configure as appropriate and try again.
--------------------------------------------------------------------------
```
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a dedicated MPI environment, not the one provided by your Slurm allocation.

For example: `mpirun -n 1 python3 examples/gpt/build.py ...`
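The advice above boils down to making sure the Python process is attached to the single-rank MPI world created by `mpirun`, not the one injected by the Slurm allocation. As a quick sanity check, the following minimal sketch (not part of this commit; the file name is arbitrary) prints the rank and world size that `mpi4py` reports:

```python
# check_mpi.py - minimal sketch (not from the repository) to inspect the MPI
# environment a Python process is attached to. TensorRT-LLM relies on mpi4py,
# so on a Slurm node `mpirun -n 1 python3 check_mpi.py` should report one rank.
from mpi4py import MPI

comm = MPI.COMM_WORLD
print(f"rank {comm.Get_rank()} of {comm.Get_size()} "
      f"(processor: {MPI.Get_processor_name()})")
```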
## Release notes

* TensorRT-LLM requires TensorRT 9.1.0.4 and 23.08 containers.
* TensorRT-LLM requires TensorRT 9.2 and 23.10 containers.

### Change Log

#### Version 0.6.0

* Models
  * ChatGLM3
  * InternLM (contributed by @wangruohui)
  * Mistral 7B (developed in collaboration with Mistral.AI)
  * MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun)
  * Qwen (contributed by @Tlntin and @zhaohb)
  * Replit Code V-1.5 3B (external contribution)
  * T5, mT5, Flan-T5 (Python runtime only)

* Features
  * Add runtime statistics related to active requests and KV cache utilization from the batch manager (see the [batch manager](docs/source/batch_manager.md) documentation)
  * Add `sequence_length` tensor to support proper lengths in beam-search (when beam-width > 1 - see [tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * BF16 support for encoder-decoder models (Python runtime - see [examples/enc_dec](examples/enc_dec/README.md))
  * Improvements to memory utilization (CPU and GPU - including memory leaks)
  * Improved error reporting and memory consumption
  * Improved support for stop and bad words
  * INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see [examples/baichuan](examples/baichuan/README.md))
  * INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only support for the GPT-J model (see [examples/gptj](examples/gptj/README.md))
  * INT4 AWQ support for the Falcon models (see [examples/falcon](examples/falcon/README.md))
  * LoRA support (functional preview only - limited to the Python runtime, only QKV support and not optimized in terms of runtime performance) for the GPT model (see the [Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint) in the GPT example)
  * Multi-GPU support for encoder-decoder models (Python runtime - see [examples/enc_dec](examples/enc_dec/README.md))
  * New heuristic for launching the Multi-block Masked MHA kernel (similar to FlashDecoding - see [decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h))
  * Prompt-Tuning support for GPT and LLaMA models (see the [Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example)
  * Performance optimizations in various CUDA kernels
  * Possibility to exclude input tokens from the output (see `excludeInputInOutput` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind))
  * Support for different micro batch sizes for context and generation phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and `GptSession::Config::genMicroBatchSize` in [tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h))
  * Support for "remove input padding" for encoder-decoder models (see [examples/enc_dec](examples/enc_dec/README.md))
  * Support for context and generation logits (see `mComputeContextLogits` and `mComputeGenerationLogits` in [tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h))
  * Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and `"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
  * Update to CUTLASS 3.x

* Bug fixes
  * Fix for ChatGLM2 #93 and #138
  * Fix tensor names error "RuntimeError: Tensor names (`host_max_kv_cache_length`) in engine are not the same as expected in the main branch" #369
  * Fix weights split issue in BLOOM when `world_size = 2` ("array split does not result in an equal division") #374
  * Fix SmoothQuant multi-GPU failure with tensor parallelism is 2 #267
  * Fix a crash in GenerationSession if stream keyword argument is not None #202
  * Fix a typo when calling PyNVML API [BUG] code bug #410
  * Fix bugs related to the improper management of the `end_id` for various models [C++ and Python]
  * Fix memory leaks [C++ code and Python models]
  * Fix the std::alloc error when running the gptManagerBenchmark -- issue gptManagerBenchmark std::bad_alloc error #66
  * Fix a bug in pipeline parallelism when beam-width > 1
  * Fix a bug with Llama GPTQ due to improper support of GQA
  * Fix issue #88
  * Fix an issue with the Huggingface Transformers version #16
  * Fix link jump in windows readme.md #30 - by @yuanlehome
  * Fix typo in batchScheduler.h #56 - by @eltociear
  * Fix typo #58 - by @RichardScottOZ
  * Fix Multi-block MMHA: Difference between `max_batch_size` in the engine builder and `max_num_sequences` in TrtGptModelOptionalParams? #65
  * Fix the log message to be more accurate on KV cache #224
  * Fix Windows release wheel installation: Failed to install the release wheel for Windows using pip #261
  * Fix missing torch dependencies: [BUG] The batch_manage.a choice error in --cpp-only when torch's cxx_abi version is different with gcc #151
  * Fix linking error during compiling google-test & benchmarks #277
  * Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by the lack of bfloat16 #335
  * Minor bug fixes

#### Version 0.5.0

* TensorRT-LLM v0.5.0 is the first public release.

### Known Issues

* The hang reported in issue [#149](https://github.com/triton-inference-server/tensorrtllm_backend/issues/149) has not been reproduced by the TensorRT-LLM team. If it is caused by a bug in TensorRT-LLM, that bug may be present in that release

### Report Issues

You can use GitHub issues to report issues with TensorRT-LLM.

@@ -18,9 +18,13 @@ instead, and be sure to set DLL paths as specified in

### 2. Launch C++ benchmarking (Fixed BatchSize/InputLen/OutputLen)

#### Prepare TensorRT-LLM engine(s)

Before you launch C++ benchmarking, please make sure that you have already built the engine(s) using the TensorRT-LLM API; the C++ benchmarking code cannot generate engine(s) for you.

You can reuse the engine built by benchmarking code for Python Runtime, please see that [`document`](../python/README.md).
You can use the [`build.py`](../python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked the Python Runtime, you can reuse the engine(s) built by the benchmarking code; please see that [`document`](../python/README.md).

#### Launch benchmarking

For detailed usage, you can do the following
```
@ -15,11 +15,14 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/batch_manager/GptManager.h"
|
||||
#include "tensorrt_llm/batch_manager/NamedTensor.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
#include "tensorrt_llm/runtime/tllmLogger.h"
|
||||
|
||||
@ -42,9 +45,9 @@ namespace trt = nvinfer1;
|
||||
class WorkItem
|
||||
{
|
||||
public:
|
||||
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t RequestId)
|
||||
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t requestId)
|
||||
: mInferenceRequest(ir)
|
||||
, mRequestId(RequestId)
|
||||
, mRequestId(requestId)
|
||||
{
|
||||
}
|
||||
|
||||
@ -100,18 +103,12 @@ public:
|
||||
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
if (hasInProgressReqId(requestId) || hasPendingReqId(requestId))
|
||||
{
|
||||
std::string errStr
|
||||
= "requestId " + std::to_string(requestId) + " is already in progress, request is ignored.";
|
||||
throw std::runtime_error(errStr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto workItem = std::make_shared<WorkItem>(request, requestId);
|
||||
mPendingWorkItems.push_back(workItem);
|
||||
mPendingWorkItemsReqIds.insert(workItem->requestId());
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
|
||||
"requestId %lu is already in progress, request is ignored.", requestId);
|
||||
|
||||
auto workItem = std::make_shared<WorkItem>(request, requestId);
|
||||
mPendingWorkItems.push_back(workItem);
|
||||
mPendingWorkItemsReqIds.insert(workItem->requestId());
|
||||
}
|
||||
|
||||
/// @brief Get a new work item from the queue, and move it to the list of
|
||||
@ -208,12 +205,12 @@ public:
|
||||
|
||||
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
const auto& input_ids_tensor = request->getInputTensor("input_ids");
|
||||
std::vector<int64_t> tensorShape(input_ids_tensor->getShape().nbDims);
|
||||
auto const inputLength = tensorShape[1];
|
||||
auto const [specified, outputLength]
|
||||
= request->getScalarValueFromTensor<int>("request_output_len", {1, 1}, false);
|
||||
assert(specified);
|
||||
auto const inputLength = request->getInputIds()->getSize();
|
||||
auto const maxNewTokens = request->getMaxNewTokensNamed();
|
||||
auto const& outputLengthTensor = maxNewTokens.tensor;
|
||||
TLLM_CHECK_WITH_INFO(outputLengthTensor != nullptr && outputLengthTensor->getSize() > 0,
|
||||
"Undefined scalar vector for %s", maxNewTokens.name.c_str());
|
||||
auto const outputLength = *bufferCast<SizeType>(*outputLengthTensor);
|
||||
auto const start = std::chrono::steady_clock::now();
|
||||
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
|
||||
}
|
||||
@ -286,20 +283,15 @@ public:
|
||||
mWorkItemsQueue.clear();
|
||||
}
|
||||
|
||||
void enqueue(std::vector<NamedTensor> tensors, uint64_t requestId, bool streaming)
|
||||
void enqueue(std::shared_ptr<InferenceRequest> const& request)
|
||||
{
|
||||
// Create InferenceRequest from a set of tensors
|
||||
auto request = std::make_shared<InferenceRequest>(requestId);
|
||||
TLLM_CHECK(request != nullptr);
|
||||
auto const requestId = request->getRequestId();
|
||||
if (requestId == mTerminateReqId)
|
||||
{
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
return;
|
||||
}
|
||||
for (auto t : tensors)
|
||||
{
|
||||
request->emplaceInputTensor(t.name, std::move(t.tensor));
|
||||
}
|
||||
request->setIsStreaming(streaming);
|
||||
|
||||
// Enqueue
|
||||
try
|
||||
@ -307,11 +299,14 @@ public:
|
||||
mRecorder->recordStart(request, requestId);
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
}
|
||||
catch (const tc::TllmException& e)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
throw std::runtime_error(e.what());
|
||||
TLLM_THROW("%s", e.what());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void waitForEmpty() const
|
||||
@ -357,8 +352,8 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId())
|
||||
+ std::string(" has been stopped. Request is ignored.");
|
||||
auto warnStr = tc::fmtstr(
|
||||
"request Id %lu has been stopped. Request is ignored.", workItem->requestId());
|
||||
TLLM_LOG_WARNING(warnStr);
|
||||
sendResponse(workItem->requestId(), {}, true, warnStr);
|
||||
}
|
||||
@ -366,7 +361,7 @@ public:
|
||||
if (world_size > 1)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
for (auto ir : rval)
|
||||
for (auto const& ir : rval)
|
||||
{
|
||||
auto vpacked = ir->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
@ -400,10 +395,9 @@ public:
|
||||
return rval;
|
||||
}
|
||||
|
||||
void sendResponse(uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
|
||||
const std::string& errMsg)
|
||||
void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
|
||||
bool final_response, [[maybe_unused]] const std::string& errMsg)
|
||||
{
|
||||
std::string errStr = std::string("Failed to send response for requestId: ") + std::to_string(requestId);
|
||||
try
|
||||
{
|
||||
if (final_response)
|
||||
@ -414,7 +408,7 @@ public:
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
TLLM_LOG_ERROR("Failed to send response for requestId: %ul\n%s", requestId, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,24 +428,43 @@ std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
{
    auto constexpr allowExceptions = true;
    auto constexpr ingoreComments = true;
    TLLM_CHECK_WITH_INFO(
        std::filesystem::exists(datasetPath), std::string("File does not exist: ") + datasetPath.string());
    TLLM_CHECK_WITH_INFO(std::filesystem::exists(datasetPath), "File does not exist: %s", datasetPath.string().c_str());
    std::ifstream jsonStream(datasetPath);
    auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ingoreComments);

    std::vector<std::vector<int32_t>> input_ids_list;
    std::vector<int32_t> output_ids_list;
    std::vector<std::vector<int32_t>> inputIds;
    std::vector<int32_t> outputIds;
    for (auto& sample : json)
    {
        input_ids_list.push_back(sample["input_ids"]);
        output_ids_list.push_back(sample["output_len"]);
        inputIds.push_back(sample["input_ids"]);
        outputIds.push_back(sample["output_len"]);
    }
    return std::make_pair(input_ids_list, output_ids_list);
    return std::make_pair(inputIds, outputIds);
}

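For context, `parseDataset` above expects the dataset to be a JSON array in which each sample provides an `input_ids` token list and an `output_len` count. The following is a minimal, hypothetical Python sketch (not part of this commit; the token ids and file name are made up for illustration) that writes such a file:

```python
# Hypothetical helper (not from the repository): write a dataset in the format
# consumed by parseDataset(), i.e. a JSON array of samples, each carrying
# "input_ids" (a list of token ids) and "output_len" (tokens to generate).
import json

samples = [
    {"input_ids": [101, 2009, 2003, 1037, 3231, 102], "output_len": 64},
    {"input_ids": [101, 7592, 2088, 102], "output_len": 128},
]

with open("synthetic_dataset.json", "w") as f:
    json.dump(samples, f)
```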
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
|
||||
std::string const& datasetPath, int beamWidth, int warmUp, std::shared_ptr<nvinfer1::ILogger> const& logger,
|
||||
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy)
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
|
||||
std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> const& dataset, std::size_t sample_idx,
|
||||
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
|
||||
BufferManager const& bufferManager)
|
||||
{
|
||||
auto request = std::make_shared<InferenceRequest>(reqId);
|
||||
auto const& inputIds = dataset.first[sample_idx];
|
||||
request->setInputIds(bufferManager.copyFrom(
|
||||
inputIds, ITensor::makeShape({1, static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
|
||||
auto const request_output_len = dataset.second[sample_idx];
|
||||
request->setMaxNewTokens(
|
||||
bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
|
||||
request->setBeamWidth(beamWidthTensor);
|
||||
request->setEndId(eosId);
|
||||
request->setPadId(padId);
|
||||
return request;
|
||||
}
|
||||
|
||||
void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::filesystem::path const& engineDir,
|
||||
std::string const& type, std::string const& datasetPath, int beamWidth, int warmUp,
|
||||
const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, TrtGptModelOptionalParams const& optionalParams,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi(*logger);
|
||||
|
||||
@ -466,57 +479,49 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::string errStr = std::string("Unexpected batching type: ") + type;
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
|
||||
}
|
||||
|
||||
ITensor::SharedPtr beamWidthBuffer = BufferManager::cpu(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
|
||||
auto beamWidthBufferPtr = bufferCast<SizeType>(*beamWidthBuffer);
|
||||
*beamWidthBufferPtr = beamWidth;
|
||||
auto beamWidthTensor = NamedTensor(beamWidthBuffer, "beam_width");
|
||||
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
|
||||
|
||||
ITensor::SharedPtr beamWidthTensor{
|
||||
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};
|
||||
|
||||
// Load dataset
|
||||
auto dataset = parseDataset(datasetPath);
|
||||
std::vector<std::vector<NamedTensor>> tensors_list;
|
||||
const auto num_samples = dataset.first.size();
|
||||
for (int i = 0; i < num_samples; ++i)
|
||||
{
|
||||
const auto input_ids = dataset.first[i];
|
||||
const auto request_output_len = dataset.second[i];
|
||||
std::vector<int64_t> input_ids_shape = {1, static_cast<int64_t>(input_ids.size())};
|
||||
auto input_ids_tensor = NamedTensor(nvinfer1::DataType::kINT32, input_ids_shape, "input_ids", input_ids.data());
|
||||
auto request_output_len_tensor
|
||||
= NamedTensor(nvinfer1::DataType::kINT32, {1, 1}, "request_output_len", &request_output_len);
|
||||
std::vector<NamedTensor> tensors
|
||||
= {std::move(input_ids_tensor), std::move(request_output_len_tensor), beamWidthTensor};
|
||||
tensors_list.emplace_back(std::move(tensors));
|
||||
}
|
||||
const auto numSamples = dataset.first.size();
|
||||
|
||||
const int maxBeamWidth = beamWidth;
|
||||
auto recorder = std::make_shared<Recorder>();
|
||||
uint64_t terminateReqId = num_samples + 1;
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
auto gptServer = std::make_shared<GptServer>(
|
||||
engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId);
|
||||
|
||||
ITensor::SharedPtr eosIdTensor{
|
||||
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
|
||||
ITensor::SharedPtr padIdTensor{
|
||||
padId ? bufferManager.copyFrom(&padId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
|
||||
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
// Warm up
|
||||
for (auto i = 1; i < warmUp + 1; ++i)
|
||||
SizeType reqId = 0;
|
||||
for (auto i = 0; i < warmUp; ++i)
|
||||
{
|
||||
// skip terminateReqId
|
||||
++reqId;
|
||||
if (i == terminateReqId)
|
||||
{
|
||||
i += 1;
|
||||
}
|
||||
gptServer->enqueue(tensors_list[0], i, false);
|
||||
++reqId;
|
||||
auto request = makeRequest(reqId, dataset, 0, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
|
||||
// Benchmark
|
||||
recorder->initialize();
|
||||
for (int i = 0; i < tensors_list.size(); ++i)
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
gptServer->enqueue(tensors_list[i], 1 + i, false);
|
||||
auto request = makeRequest(i + 1, dataset, i, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
recorder->finalize();
|
||||
@ -524,7 +529,7 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
|
||||
recorder->report();
|
||||
// Send terminateReqId to terminate servers on all ranks
|
||||
// Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
gptServer->enqueue({}, terminateReqId, false);
|
||||
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
|
||||
}
|
||||
// Wait until benchmarking is done and batch manager is terminated
|
||||
gptServer->waitBatchManager();
|
||||
@ -548,7 +553,8 @@ int main(int argc, char* argv[])
|
||||
"beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
|
||||
options.add_options()(
|
||||
"warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
|
||||
|
||||
options.add_options()("eos_id", "Specify the end-of-sequence token id.", cxxopts::value<int>());
|
||||
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<int>());
|
||||
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>());
|
||||
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
|
||||
options.add_options()(
|
||||
@ -609,6 +615,20 @@ int main(int argc, char* argv[])
|
||||
optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
|
||||
}
|
||||
|
||||
std::optional<int32_t> padId;
|
||||
// Argument: Padding token id
|
||||
if (result.count("pad_id"))
|
||||
{
|
||||
padId = result["pad_id"].as<int>();
|
||||
}
|
||||
|
||||
std::optional<int32_t> eosId;
|
||||
// Argument: End-of-sentence token id
|
||||
if (result.count("eos_id"))
|
||||
{
|
||||
eosId = result["eos_id"].as<int>();
|
||||
}
|
||||
|
||||
// Argument: Scheduler policy
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy;
|
||||
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
|
||||
@ -660,7 +680,7 @@ int main(int argc, char* argv[])
|
||||
try
|
||||
{
|
||||
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
|
||||
datasetPath, beamWidth, result["warm_up"].as<int>(), logger, optionalParams, schedulerPolicy);
|
||||
datasetPath, beamWidth, result["warm_up"].as<int>(), eosId, padId, logger, optionalParams, schedulerPolicy);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
|
||||
@ -40,13 +40,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
|
||||
GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits)
|
||||
{
|
||||
|
||||
std::string modelNameHyphen = modelName;
|
||||
std::filesystem::path jsonFileName = dataPath / "config.json";
|
||||
if (tc::strStartsWith(modelName, "chatglm") || tc::strStartsWith(modelName, "glm"))
|
||||
{
|
||||
jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
|
||||
}
|
||||
auto const json = GptJsonConfig::parse(jsonFileName);
|
||||
auto const modelConfig = json.getModelConfig();
|
||||
auto const inputPacked = modelConfig.usePackedInput();
|
||||
|
||||
@@ -8,6 +8,7 @@ multiple GPUs or multiple nodes with multiple GPUs.
The benchmark implementation and entrypoint can be found in [`benchmarks/python/benchmark.py`](./benchmark.py). There are some other scripts in the directory:

* [`benchmarks/python/allowed_configs.py`](./allowed_configs.py) to define configuration for each supported model.
* [`benchmarks/python/build.py`](./build.py) to build supported models for benchmarking.
* [`benchmarks/python/base_benchmark.py`](./base_benchmark.py) to implement the base class for benchmark.
* [`benchmarks/python/gpt_benchmark.py`](./gpt_benchmark.py) to implement benchmark scripts for GPT and GPT-like (LLaMA/OPT/GPT-J/SmoothQuant-GPT) models.
* [`benchmarks/python/bert_benchmark.py`](./bert_benchmark.py) to implement benchmark scripts for BERT models.
@@ -30,9 +31,9 @@ python benchmark.py \
```
Expected outputs:
```
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 1 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 378.12 percentile95(ms) 53.284 percentile99(ms) 53.284 latency(ms) 52.893
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 8 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 361.06 percentile95(ms) 55.739 percentile99(ms) 55.739 latency(ms) 55.392
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 64 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 246.03 percentile95(ms) 81.533 percentile99(ms) 81.533 latency(ms) 81.29
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 1 input_length 60 output_length 20 gpu_peak_mem(gb) 4.2 build_time(s) 25.67 tokens_per_sec 483.54 percentile95(ms) 41.537 percentile99(ms) 42.102 latency(ms) 41.362 compute_cap sm80
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 8 input_length 60 output_length 20 gpu_peak_mem(gb) 4.28 build_time(s) 25.67 tokens_per_sec 3477.28 percentile95(ms) 46.129 percentile99(ms) 46.276 latency(ms) 46.013 compute_cap sm80
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 64 input_length 60 output_length 20 gpu_peak_mem(gb) 4.8 build_time(s) 25.67 tokens_per_sec 19698.07 percentile95(ms) 65.739 percentile99(ms) 65.906 latency(ms) 64.981 compute_cap sm80
...
```
*Please note that the expected outputs are only for reference; specific performance numbers depend on the GPU you're using.*

@ -30,18 +30,17 @@ class BuildConfig(BaseModel, extra=Extra.allow):
|
||||
max_input_len: int
|
||||
num_kv_heads: Optional[int] = None
|
||||
max_output_len: Optional[int] = None
|
||||
max_beam_width: int = 1
|
||||
# TRT builder_optimization_level from 0 to 5
|
||||
builder_opt: Optional[int] = None
|
||||
inter_size: Optional[int] = None
|
||||
rotary_dim: Optional[int] = None
|
||||
type_vocab_size: Optional[int] = None
|
||||
use_smooth_quant: bool = False
|
||||
per_token: bool = False
|
||||
per_channel: bool = False
|
||||
pre_norm: Optional[bool] = None
|
||||
do_layer_norm_before: Optional[bool] = None
|
||||
enable_qk_half_accum: bool = False
|
||||
enable_context_fmha: bool = True
|
||||
enable_multi_block_mode: bool = False
|
||||
# None means using the model family's default value defined in the ctor
|
||||
position_embedding_type: Optional[PositionEmbeddingType] = None
|
||||
# Only when position embedding is RoPE, this value makes sense, make
|
||||
@ -49,6 +48,10 @@ class BuildConfig(BaseModel, extra=Extra.allow):
|
||||
rotary_pct: Optional[float] = None
|
||||
bias: bool = True
|
||||
quantization: Optional[str] = None
|
||||
# use_custom_all_reduce gives better performance with NVLink
|
||||
use_custom_all_reduce: bool = True
|
||||
moe_num_experts: int = None
|
||||
moe_top_k: int = None
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
@ -107,6 +110,24 @@ _allowed_configs = {
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"gpt_350m_moe":
|
||||
ModelConfig(name="gpt_350m_moe",
|
||||
family="gpt",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=24,
|
||||
num_heads=16,
|
||||
hidden_size=1024,
|
||||
vocab_size=51200,
|
||||
hidden_act='gelu',
|
||||
n_positions=1024,
|
||||
max_batch_size=256,
|
||||
max_input_len=512,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
moe_num_experts=8,
|
||||
moe_top_k=1,
|
||||
)),
|
||||
"gpt_350m_sq_per_tensor":
|
||||
ModelConfig(name="gpt_350m_sq_per_tensor",
|
||||
family="gpt",
|
||||
@ -301,6 +322,40 @@ _allowed_configs = {
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"llama_70b_long_context":
|
||||
ModelConfig(name="llama_70b_long_context",
|
||||
family="llama",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(num_layers=80,
|
||||
num_heads=64,
|
||||
num_kv_heads=8,
|
||||
hidden_size=8192,
|
||||
vocab_size=32000,
|
||||
hidden_act='silu',
|
||||
n_positions=2048,
|
||||
inter_size=28672,
|
||||
max_batch_size=16,
|
||||
max_input_len=8000,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
enable_multi_block_mode=True)),
|
||||
"llama_70b_long_generation":
|
||||
ModelConfig(name="llama_70b_long_generation",
|
||||
family="llama",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(num_layers=80,
|
||||
num_heads=64,
|
||||
num_kv_heads=8,
|
||||
hidden_size=8192,
|
||||
vocab_size=32000,
|
||||
hidden_act='silu',
|
||||
n_positions=2048,
|
||||
inter_size=28672,
|
||||
max_batch_size=64,
|
||||
max_input_len=200,
|
||||
max_output_len=16384,
|
||||
builder_opt=None,
|
||||
enable_multi_block_mode=True)),
|
||||
"llama_70b_sq_per_tensor":
|
||||
ModelConfig(name="llama_70b_sq_per_tensor",
|
||||
family="llama",
|
||||
|
||||
@ -15,11 +15,13 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
@ -44,23 +46,28 @@ def get_engine_name(model, dtype, tp_size, rank):
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
with open(path, 'wb') as f:
|
||||
# engine object is already complies with python buffer protocol, no need to
|
||||
# convert it to bytearray before write, converting to bytearray consumes lots of memory
|
||||
f.write(engine)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
class BaseBenchmark(object):
|
||||
|
||||
def __init__(self, engine_dir, model_name, dtype, output_dir):
|
||||
def __init__(self, engine_dir, model_name, dtype):
|
||||
self.engine_dir = engine_dir
|
||||
self.model_name = model_name
|
||||
self.dtype = dtype
|
||||
self.output_dir = output_dir
|
||||
self.runtime_rank = tensorrt_llm.mpi_rank()
|
||||
self.world_size = tensorrt_llm.mpi_world_size()
|
||||
self.engine_model_name = model_name
|
||||
self.quant_mode = QuantMode(0)
|
||||
self.enable_fp8 = False
|
||||
if engine_dir is not None:
|
||||
# Read config from engine directory
|
||||
config_path = os.path.join(engine_dir, 'config.json')
|
||||
@ -98,7 +105,7 @@ class BaseBenchmark(object):
|
||||
|
||||
self.csv_filename = "" # lazy init
|
||||
|
||||
def get_report_dict(self):
|
||||
def get_report_dict(self, benchmark_profiler=None):
|
||||
report_fields = [
|
||||
"model_name", "world_size", "num_heads", "num_kv_heads",
|
||||
"num_layers", "hidden_size", "vocab_size", "precision",
|
||||
@ -122,9 +129,9 @@ class BaseBenchmark(object):
|
||||
fp8linear=int(self.enable_fp8))
|
||||
return self.csv_filename
|
||||
|
||||
def print_report_header(self, csv=False):
|
||||
def print_report_header(self, csv=False, benchmark_profiler=None):
|
||||
if csv and self.runtime_rank == 0:
|
||||
report_dict = self.get_report_dict()
|
||||
report_dict = self.get_report_dict(benchmark_profiler)
|
||||
line = ",".join(report_dict.keys())
|
||||
print(line)
|
||||
with open(self.get_csv_filename(), "a") as file:
|
||||
@ -136,7 +143,7 @@ class BaseBenchmark(object):
|
||||
def prepare_inputs(self, config):
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, inputs, config):
|
||||
def run(self, inputs, config, benchmark_profiler=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def report(self, config, latency):
|
||||
|
||||
@ -118,7 +118,7 @@ def parse_arguments():
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
'If this option is specified, TensorRT engines will be saved to engine_dir.'
|
||||
'If this option is specified, TensorRT engines will be saved to the specified path.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--engine_dir',
|
||||
@ -128,37 +128,33 @@ def parse_arguments():
|
||||
('If this option is specified, instead of building engines on-air before benchmarking, '
|
||||
'the engines contained in the engine_dir will be used.'))
|
||||
parser.add_argument(
|
||||
'--n_positions',
|
||||
'--max_beam_width',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the n_positions of TRT engines to the specified value instead of using pre-defined one'
|
||||
'By default when this option is not used, it will use pre-defined n_positions'
|
||||
))
|
||||
('If this option is specified, it will override the max beam width of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_input_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max input len of TRT engines to the specified value instead of using pre-defined one'
|
||||
'By default when this option is not used, it will use pre-defined max input len'
|
||||
))
|
||||
('If this option is specified, it will override the max input len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_output_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max output len of TRT engines to the specified value instead of using pre-defined one'
|
||||
'By default when this option is not used, it will use pre-defined max output len'
|
||||
))
|
||||
('If this option is specified, it will override the max output len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max batch size of TRT engines to the specified value instead of using pre-defined one'
|
||||
'By default when this option is not used, it will use pre-defined max batch size'
|
||||
))
|
||||
('If this option is specified, it will override the max batch size of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--force_num_layer_1',
|
||||
default=False,
|
||||
@ -175,20 +171,6 @@ def parse_arguments():
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='Execute GPT session with CUDA graph.')
|
||||
parser.add_argument(
|
||||
'--enable_custom_all_reduce',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
'Use latency-optimized all-reduce for tensor parallelism. Gives better performance with NVLink.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--quantization',
|
||||
type=str,
|
||||
@ -209,6 +191,7 @@ def main(args):
|
||||
# the start method `spawn` of Python multiprocessing,
|
||||
# so we set the start method first, then initialize MPI.
|
||||
from allowed_configs import get_allowed_models
|
||||
from benchmark_profiler import BenchmarkProfiler
|
||||
from bert_benchmark import BERTBenchmark
|
||||
from gpt_benchmark import GPTBenchmark
|
||||
from mem_monitor import MemoryMonitor
|
||||
@ -228,47 +211,19 @@ def main(args):
|
||||
in_out_len_options = [[int(i) for i in io.split(',')]
|
||||
for io in in_out_len_options]
|
||||
|
||||
benchmark_profiler = None
|
||||
if args.model in get_allowed_models(benchmark_type="gpt"):
|
||||
benchmarker = GPTBenchmark(
|
||||
args.engine_dir,
|
||||
args.model,
|
||||
args.mode,
|
||||
batch_size_options,
|
||||
in_out_len_options,
|
||||
args.dtype,
|
||||
args.refit,
|
||||
args.num_beams,
|
||||
args.top_k,
|
||||
args.top_p,
|
||||
args.output_dir,
|
||||
args.n_positions,
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
args.max_batch_size,
|
||||
force_num_layer_1=args.force_num_layer_1,
|
||||
enable_cuda_graph=args.enable_cuda_graph,
|
||||
enable_custom_all_reduce=args.enable_custom_all_reduce,
|
||||
strongly_typed=args.strongly_typed,
|
||||
quantization=args.quantization)
|
||||
benchmark_profiler = BenchmarkProfiler()
|
||||
benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options)
|
||||
elif args.model in get_allowed_models(benchmark_type="bert"):
|
||||
benchmarker = BERTBenchmark(args.engine_dir,
|
||||
args.model,
|
||||
args.mode,
|
||||
batch_size_options,
|
||||
input_len_options,
|
||||
args.dtype,
|
||||
args.output_dir,
|
||||
args.n_positions,
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
args.max_batch_size,
|
||||
force_num_layer_1=args.force_num_layer_1)
|
||||
benchmarker = BERTBenchmark(args, batch_size_options, input_len_options)
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
benchmarker.print_report_header(args.csv)
|
||||
benchmarker.print_report_header(args.csv,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
for config in benchmarker.get_config():
|
||||
try:
|
||||
inputs = benchmarker.prepare_inputs(config)
|
||||
@ -290,12 +245,16 @@ def main(args):
|
||||
for _ in range(args.warm_up):
|
||||
benchmarker.run(inputs, config)
|
||||
logger.info('Warm up done. Start benchmarking.')
|
||||
|
||||
if benchmark_profiler is not None:
|
||||
benchmark_profiler.clean()
|
||||
benchmark_profiler.start()
|
||||
cur_duration = 0
|
||||
start_time = time()
|
||||
while iter_idx < args.num_runs or cur_duration < args.duration:
|
||||
start.record()
|
||||
benchmarker.run(inputs, config)
|
||||
benchmarker.run(inputs,
|
||||
config,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
end.record()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@ -315,6 +274,9 @@ def main(args):
|
||||
memory_monitor.stop()
|
||||
_, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
|
||||
peak_gpu_used = round(peak_gpu_used, 3)
|
||||
if benchmark_profiler is not None:
|
||||
benchmark_profiler.add_aux_info('iter_count', iter_idx)
|
||||
benchmark_profiler.stop()
|
||||
|
||||
latency = round(sum(latencies) / iter_idx, 3)
|
||||
latencies.sort()
|
||||
@ -325,7 +287,8 @@ def main(args):
|
||||
percentile95,
|
||||
percentile99,
|
||||
peak_gpu_used,
|
||||
csv=args.csv)
|
||||
csv=args.csv,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
benchmarks/python/benchmark_profiler.py (new file, 77 lines)

@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class BenchmarkProfiler(object):
    cuda_event_dict: dict
    timer_dict: dict
    aux_info: dict
    started: bool

    def __init__(self):
        self.cuda_event_dict = {}
        self.timer_dict = {}
        self.aux_info = {}
        self.started = False

    def clean(self):
        self.cuda_event_dict = {}
        self.timer_dict = {}
        self.aux_info = {}

    def start(self):
        self.started = True

    def stop(self):
        self.started = False

    def get_cuda_event(self, name: str):
        if name not in self.cuda_event_dict.keys():
            event = torch.cuda.Event(enable_timing=True)
            self.cuda_event_dict[name] = event
        return self.cuda_event_dict[name]

    def record_cuda_event(self, name: str):
        if not self.started:
            return
        event = self.get_cuda_event(name)
        event.record()

    def get_timer_value(self, timer_name: str):
        # timer is in milliseconds
        return self.timer_dict[timer_name]

    def record_elapsed_time(self, start_event_name: str, end_event_name: str,
                            timer_name: str):
        if timer_name not in self.timer_dict.keys():
            self.timer_dict[timer_name] = 0.0
        if not self.started:
            return
        self.get_cuda_event(start_event_name).synchronize()
        self.get_cuda_event(end_event_name).synchronize()
        self.timer_dict[timer_name] += self.get_cuda_event(
            start_event_name).elapsed_time(self.get_cuda_event(end_event_name))

    def get_aux_info(self, aux_name):
        return self.aux_info[aux_name]

    def add_aux_info(self, aux_name: str, add_value):
        if aux_name not in self.aux_info.keys():
            self.aux_info[aux_name] = 0
        if not self.started:
            return
        self.aux_info[aux_name] += add_value
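To make the intended call pattern concrete, here is a small, hypothetical usage sketch (not part of the commit) of the `BenchmarkProfiler` above. It assumes a CUDA-capable GPU and uses a made-up matmul as the timed workload; `benchmark.py` wires the profiler into `benchmarker.run()` in the same record-and-accumulate style.

```python
# Hypothetical usage sketch (not from the repository) of BenchmarkProfiler.
# Assumes a CUDA-capable GPU and that benchmark_profiler.py is on sys.path.
import torch

from benchmark_profiler import BenchmarkProfiler

profiler = BenchmarkProfiler()
profiler.start()
for _ in range(10):
    profiler.record_cuda_event('step_start')
    # Stand-in GPU workload; a real benchmark would run the model here.
    a = torch.randn(1024, 1024, device='cuda')
    torch.matmul(a, a)
    profiler.record_cuda_event('step_end')
    # Accumulate the elapsed milliseconds between the two named events.
    profiler.record_elapsed_time('step_start', 'step_end', 'generation_time')
    profiler.add_aux_info('iter_count', 1)
profiler.stop()

print(f"total generation_time (ms): {profiler.get_timer_value('generation_time')}")
print(f"iterations: {profiler.get_aux_info('iter_count')}")
```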
@ -14,86 +14,45 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
# isort: off
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
#isort: on
|
||||
from allowed_configs import get_build_config
|
||||
from base_benchmark import BaseBenchmark, serialize_engine
|
||||
from base_benchmark import BaseBenchmark
|
||||
from build import build_bert
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch
|
||||
from tensorrt_llm.builder import Builder
|
||||
from tensorrt_llm.network import net_guard
|
||||
from tensorrt_llm.plugin.plugin import ContextFMHAType
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm.runtime import TensorInfo
|
||||
|
||||
|
||||
class BERTBenchmark(BaseBenchmark):
|
||||
|
||||
def __init__(self,
|
||||
engine_dir,
|
||||
model_name,
|
||||
mode,
|
||||
batch_sizes,
|
||||
in_lens,
|
||||
dtype,
|
||||
output_dir,
|
||||
n_positions=None,
|
||||
max_input_len=None,
|
||||
max_output_len=None,
|
||||
max_batch_size=None,
|
||||
**kwargs):
|
||||
super().__init__(engine_dir, model_name, dtype, output_dir)
|
||||
def __init__(self, args, batch_sizes, in_lens):
|
||||
super().__init__(args.engine_dir, args.model, args.dtype)
|
||||
self.batch_sizes = batch_sizes
|
||||
self.in_lens = in_lens
|
||||
self.build_time = 0
|
||||
self.mode = args.mode
|
||||
|
||||
if engine_dir is not None:
|
||||
if args.engine_dir is not None:
|
||||
# Deserialize engine from engine directory
|
||||
self.serialize_path = os.path.join(engine_dir, self.engine_name)
|
||||
self.serialize_path = os.path.join(args.engine_dir,
|
||||
self.engine_name)
|
||||
with open(self.serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
else:
|
||||
# Build engine
|
||||
self.use_bert_attention_plugin = False
|
||||
self.use_gemm_plugin = False
|
||||
self.use_layernorm_plugin = False
|
||||
self.enable_qk_half_accum = False
|
||||
self.enable_context_fmha = False
|
||||
if mode == 'plugin':
|
||||
self.use_bert_attention_plugin = dtype
|
||||
self.use_gemm_plugin = dtype
|
||||
self.use_layernorm_plugin = dtype
|
||||
for key, value in get_build_config(model_name).items():
|
||||
for key, value in get_build_config(args.model).items():
|
||||
setattr(self, key, value)
|
||||
# Override the n_positions/max_input_len/max_output_len/max_batch_size to value from cmd line if that's specified.
|
||||
if n_positions is not None:
|
||||
assert isinstance(
|
||||
n_positions, int
|
||||
) and n_positions > 0, f"n_positions should be a valid int number, got {n_positions}"
|
||||
self.n_positions = n_positions
|
||||
if max_input_len is not None:
|
||||
assert isinstance(
|
||||
max_input_len, int
|
||||
) and max_input_len > 0, f"max_input_len should be a valid int number, got {max_input_len}"
|
||||
self.max_input_len = max_input_len
|
||||
if max_output_len is not None:
|
||||
assert isinstance(
|
||||
max_output_len, int
|
||||
) and max_output_len > 0, f"max_output_len should be a valid int number, got {max_output_len}"
|
||||
self.max_output_len = max_output_len
|
||||
if max_batch_size is not None:
|
||||
assert isinstance(
|
||||
max_batch_size, int
|
||||
) and max_batch_size > 0, f"max_batch_size should be a valid int number, got {max_batch_size}"
|
||||
self.max_batch_size = max_batch_size
|
||||
if kwargs.get('force_num_layer_1', False):
|
||||
if args.force_num_layer_1:
|
||||
self.num_layers = 1
|
||||
|
||||
engine_buffer = self.build()
|
||||
start = time.time()
|
||||
engine_buffer = build_bert(args)
|
||||
self.build_time = round(time.time() - start, 2)
|
||||
|
||||
assert engine_buffer is not None
|
||||
|
||||
@ -128,99 +87,7 @@ class BERTBenchmark(BaseBenchmark):
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
return (inputs, outputs, stream)
|
||||
|
||||
def build(self):
|
||||
bs_range = [1, (self.max_batch_size + 1) // 2, self.max_batch_size]
|
||||
inlen_range = [1, (self.max_input_len + 1) // 2, self.max_input_len]
|
||||
|
||||
builder = Builder()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=self.model_name,
|
||||
precision=self.dtype,
|
||||
timing_cache=None,
|
||||
tensor_parallel=self.world_size, # TP only
|
||||
parallel_build=True,
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_input_len=self.max_input_len,
|
||||
opt_level=self.builder_opt)
|
||||
# Initialize model
|
||||
tensorrt_llm_bert = tensorrt_llm.models.BertModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
mapping=tensorrt_llm.Mapping(world_size=self.world_size,
|
||||
tp_size=self.world_size))
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
if self.use_bert_attention_plugin:
|
||||
network.plugin_config.set_bert_attention_plugin(
|
||||
dtype=self.use_bert_attention_plugin)
|
||||
if self.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=self.use_gemm_plugin)
|
||||
if self.use_layernorm_plugin:
|
||||
network.plugin_config.set_layernorm_plugin(
|
||||
dtype=self.use_layernorm_plugin)
|
||||
if self.enable_qk_half_accum:
|
||||
network.plugin_config.enable_qk_half_accum()
|
||||
if self.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if self.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(self.dtype)
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(tensorrt_llm_bert.named_parameters())
|
||||
|
||||
# Forward
|
||||
input_ids = tensorrt_llm.Tensor(
|
||||
name='input_ids',
|
||||
dtype=trt.int32,
|
||||
shape=[-1, -1],
|
||||
dim_range=OrderedDict([('batch_size', [bs_range]),
|
||||
('input_len', [inlen_range])]),
|
||||
)
|
||||
input_lengths = tensorrt_llm.Tensor(name='input_lengths',
|
||||
dtype=trt.int32,
|
||||
shape=[-1],
|
||||
dim_range=OrderedDict([
|
||||
('batch_size', [bs_range])
|
||||
]))
|
||||
hidden_states = tensorrt_llm_bert(input_ids=input_ids,
|
||||
input_lengths=input_lengths)
|
||||
|
||||
# Mark outputs
|
||||
hidden_states_dtype = str_dtype_to_trt(self.dtype)
|
||||
hidden_states.mark_output('hidden_states', hidden_states_dtype)
|
||||
|
||||
# Network -> Engine
|
||||
start = time.time()
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
end = time.time()
|
||||
self.build_time = round(end - start, 2)
|
||||
|
||||
if self.output_dir is not None:
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
self.serialize_path = os.path.join(self.output_dir,
|
||||
self.engine_name)
|
||||
serialize_engine(engine, self.serialize_path)
|
||||
if self.runtime_rank == 0:
|
||||
config_path = os.path.join(self.output_dir, 'config.json')
|
||||
builder_config.plugin_config = network.plugin_config
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
|
||||
|
||||
def run(self, inputs, config):
|
||||
def run(self, inputs, config, benchmark_profiler=None):
|
||||
ok = self.session.run(*inputs)
|
||||
assert ok, "Runtime execution failed"
|
||||
torch.cuda.synchronize()
|
||||
@ -235,8 +102,14 @@ class BERTBenchmark(BaseBenchmark):
|
||||
f'percentile99(ms) {percentile99} latency(ms) {latency}')
|
||||
print(line)
|
||||
|
||||
def report(self, config, latency, percentile95, percentile99, peak_gpu_used,
|
||||
csv):
|
||||
def report(self,
|
||||
config,
|
||||
latency,
|
||||
percentile95,
|
||||
percentile99,
|
||||
peak_gpu_used,
|
||||
csv,
|
||||
benchmark_profiler=None):
|
||||
report_dict = super().get_report_dict()
|
||||
batch_size, inlen = config[0], config[1]
|
||||
report_dict["num_heads"] = self.num_heads
|
||||
|
||||
benchmarks/python/build.py (new file, 606 lines)

@@ -0,0 +1,606 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from allowed_configs import (get_allowed_models, get_build_config,
|
||||
get_model_family)
|
||||
from base_benchmark import get_engine_name, serialize_engine
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_trt
|
||||
from tensorrt_llm.builder import Builder
|
||||
from tensorrt_llm.layers import PositionEmbeddingType
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.models import quantize_model
|
||||
from tensorrt_llm.network import net_guard
|
||||
from tensorrt_llm.plugin.plugin import ContextFMHAType
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Build TensorRT-LLM models.')
|
||||
parser.add_argument('-m',
|
||||
'--model',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=get_allowed_models(),
|
||||
help='Specify model you want to build.')
|
||||
parser.add_argument(
|
||||
'--mode',
|
||||
type=str,
|
||||
default="plugin",
|
||||
choices=['ootb', 'plugin', 'ootb-except-mha'],
|
||||
help=
|
||||
('Choose mode between ootb/plugin/ootb-except-mha. '
|
||||
'\"ootb\" means the engines will be built without any plugins, '
|
||||
'\"plugin\" means the engines will be built with tuned recipe of using plugins.'
|
||||
'\"ootb-except-mha\" means the engines will be built with only attention plugins.'
|
||||
))
|
||||
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float16', 'bfloat16', 'float32'],
|
||||
help='Choose data type between float16/bfloat16/float32.')
|
||||
parser.add_argument(
|
||||
'--quantization',
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[
|
||||
'fp8', 'fp8_gemm', 'fp8_kv_cache', 'int8_sq_per_tensor',
|
||||
'int8_sq_per_token_channel', 'int8_weight_only', 'int4_weight_only',
|
||||
'int4_weight_only_awq', 'int4_weight_only_gptq'
|
||||
],
|
||||
help="Optimize the model with specified quantization recipe")
|
||||
|
||||
parser.add_argument(
|
||||
'--log_level',
|
||||
type=str,
|
||||
default="error",
|
||||
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
|
||||
help=
|
||||
'Choose log level between verbose/info/warning/error/internal_error.')
|
||||
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='TensorRT engines will be saved to the specified path.')
|
||||
|
||||
parser.add_argument(
|
||||
'--max_beam_width',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max beam width of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_input_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max input len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_output_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max output len of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument(
|
||||
'--max_batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
('If this option is specified, it will override the max batch size of '
|
||||
'TRT engines to the specified value instead of using pre-defined one'))
|
||||
parser.add_argument('--force_num_layer_1',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='Quick sanity check with num_layer=1.')
|
||||
|
||||
return parser.parse_args()
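For reference, a minimal sketch of how these arguments might be exercised from Python; the model name and output directory below are illustrative placeholders, not values taken from this change, and the model must be one reported by get_allowed_models():

    # Equivalent to: python build.py -m <model> --mode plugin --dtype float16 --output_dir ./engines
    import sys
    sys.argv[1:] = ['-m', 'gpt_350m', '--mode', 'plugin',
                    '--dtype', 'float16', '--output_dir', './engines']
    args = parse_arguments()
    print(args.model, args.mode, args.quantization)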
|
||||
|
||||
|
||||
def get_quant_mode(quantization):
|
||||
quant_mode = QuantMode(0)
|
||||
strongly_typed = False
|
||||
use_smooth_quant = False
|
||||
per_token = False
|
||||
per_channel = False
|
||||
weight_only_precision = 'int8'
|
||||
|
||||
if quantization == "fp8":
|
||||
strongly_typed = True
|
||||
quant_mode = quant_mode.set_fp8_qdq()
|
||||
quant_mode = quant_mode.set_fp8_kv_cache()
|
||||
|
||||
elif quantization == "fp8_gemm":
|
||||
strongly_typed = True
|
||||
quant_mode = quant_mode.set_fp8_qdq()
|
||||
|
||||
elif quantization == "fp8_kv_cache":
|
||||
strongly_typed = True
|
||||
quant_mode = quant_mode.set_fp8_kv_cache()
|
||||
|
||||
elif quantization == "int8_sq_per_tensor":
|
||||
use_smooth_quant = True
|
||||
quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
|
||||
|
||||
elif quantization == "int8_sq_per_token_channel":
|
||||
use_smooth_quant = True
|
||||
per_token = True
|
||||
per_channel = True
|
||||
quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
|
||||
|
||||
elif quantization == "int8_weight_only":
|
||||
use_smooth_quant = False
|
||||
weight_only_precision = 'int8'
|
||||
quant_mode = QuantMode.use_weight_only(False)
|
||||
|
||||
elif quantization == "int4_weight_only":
|
||||
weight_only_precision = 'int4'
|
||||
quant_mode = QuantMode.use_weight_only(True)
|
||||
|
||||
elif quantization == "int4_weight_only_awq":
|
||||
weight_only_precision = 'int4_awq'
|
||||
quant_mode = QuantMode.from_description(quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
|
||||
elif quantization == "int4_weight_only_gptq":
|
||||
weight_only_precision = 'int4_gptq'
|
||||
quant_mode = QuantMode.from_description(quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
|
||||
elif quantization == None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f'Unexpected quantization: {quantization}')
|
||||
|
||||
return quant_mode, strongly_typed, use_smooth_quant, weight_only_precision
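A quick illustration of the mapping above, assuming the QuantMode helpers behave as their names suggest; the recipe string is just one of the allowed --quantization choices:

    quant_mode, strongly_typed, use_sq, wo_precision = get_quant_mode('int8_weight_only')
    assert quant_mode.is_weight_only() and quant_mode.is_int8_weight_only()
    assert not strongly_typed and not use_sq and wo_precision == 'int8'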
|
||||
|
||||
|
||||
def build_gpt(args):
|
||||
build_config = get_build_config(args.model)
|
||||
if args.force_num_layer_1:
|
||||
build_config['num_layers'] = 1
|
||||
|
||||
# More parameters
|
||||
world_size = tensorrt_llm.mpi_world_size()
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
torch.cuda.set_device(runtime_rank)
|
||||
num_kv_heads = build_config['num_heads'] \
|
||||
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
|
||||
apply_query_key_layer_scaling = False
|
||||
max_batch_size = build_config['max_batch_size'] \
|
||||
if args.max_batch_size is None else args.max_batch_size
|
||||
max_input_len = build_config['max_input_len'] \
|
||||
if args.max_input_len is None else args.max_input_len
|
||||
max_output_len = build_config['max_output_len'] \
|
||||
if args.max_output_len is None else args.max_output_len
|
||||
max_beam_width = build_config['max_beam_width'] \
|
||||
if args.max_beam_width is None else args.max_beam_width
|
||||
quant_mode, strongly_typed, use_smooth_quant, weight_only_precision = get_quant_mode(
|
||||
args.quantization)
|
||||
use_weight_only = quant_mode.is_weight_only()
|
||||
|
||||
builder = Builder()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=args.model,
|
||||
precision=args.dtype,
|
||||
timing_cache=None,
|
||||
tensor_parallel=world_size, # TP only
|
||||
parallel_build=True,
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=max_input_len,
|
||||
max_output_len=max_output_len,
|
||||
int8=(quant_mode.has_act_and_weight_quant()
|
||||
or quant_mode.is_int8_weight_only()),
|
||||
quant_mode=quant_mode,
|
||||
use_refit=False,
|
||||
opt_level=build_config['builder_opt'],
|
||||
strongly_typed=strongly_typed)
|
||||
engine_name = get_engine_name(args.model, args.dtype, world_size,
|
||||
runtime_rank)
|
||||
|
||||
kv_dtype = str_dtype_to_trt(args.dtype)
|
||||
|
||||
# Initialize Module
|
||||
family = get_model_family(args.model)
|
||||
if family == "gpt":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTLMHeadModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
position_embedding_type=PositionEmbeddingType.learned_absolute
|
||||
if build_config['position_embedding_type'] is None else
|
||||
build_config['position_embedding_type'],
|
||||
rotary_embedding_percentage=build_config['rotary_pct'],
|
||||
quant_mode=quant_mode,
|
||||
bias=build_config['bias'],
|
||||
moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
|
||||
build_config["moe_num_experts"], build_config["moe_top_k"]))
|
||||
elif family == "opt":
|
||||
tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
pre_norm=build_config['pre_norm'],
|
||||
do_layer_norm_before=build_config['do_layer_norm_before'])
|
||||
elif family == "llama":
|
||||
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=build_config['inter_size'],
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
quant_mode=quant_mode,
|
||||
use_fused_mlp=True)
|
||||
elif family == "gptj":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
rotary_dim=build_config['rotary_dim'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
quant_mode=quant_mode)
|
||||
elif family == "gptneox":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTNeoXForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
rotary_dim=build_config['rotary_dim'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling)
|
||||
elif family == "chatglm":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=quant_mode,
|
||||
model_name="chatglm_6b")
|
||||
|
||||
elif family == "chatglm2":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=quant_mode,
|
||||
model_name="chatglm2_6b")
|
||||
|
||||
elif family == "chatglm3":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=quant_mode,
|
||||
model_name="chatglm3_6b")
|
||||
|
||||
elif family == "bloom":
|
||||
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size), # TP only
|
||||
quant_mode=quant_mode,
|
||||
use_parallel_embedding=(args.model == 'bloom_176b'))
|
||||
elif family == "falcon":
|
||||
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
dtype=kv_dtype,
|
||||
bias=build_config['bias'],
|
||||
quant_mode=quant_mode,
|
||||
use_alibi=build_config['use_alibi'],
|
||||
new_decoder_architecture=build_config['new_decoder_architecture'],
|
||||
parallel_attention=build_config['parallel_attention'],
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size))
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
quant_kwargs = {}
|
||||
if family == "llama" and use_weight_only:
|
||||
if weight_only_precision == 'int4_awq':
|
||||
quant_kwargs = {
|
||||
"group_size": 128,
|
||||
"zero": False,
|
||||
"pre_quant_scale": True,
|
||||
"exclude_modules": [],
|
||||
}
|
||||
elif weight_only_precision == 'int4_gptq':
|
||||
quant_kwargs = {
|
||||
"group_size": 128,
|
||||
"zero": True,
|
||||
"pre_quant_scale": False,
|
||||
}
|
||||
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
|
||||
**quant_kwargs)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
|
||||
# Plugins
|
||||
if args.mode == 'plugin':
|
||||
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
|
||||
if args.quantization is None or "fp8" not in args.quantization:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
|
||||
|
||||
# Quantization plugins.
|
||||
if use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_layernorm_quantization_plugin(
|
||||
dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
elif use_weight_only:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=args.dtype)
|
||||
elif family == "llama" and quant_mode.has_act_and_weight_quant():
|
||||
# RMS norm plugin for SmoothQuant
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin(
|
||||
dtype=args.dtype)
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
|
||||
if world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(
|
||||
dtype=args.dtype,
|
||||
use_custom_all_reduce=build_config["use_custom_all_reduce"])
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(tensorrt_llm_model.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = tensorrt_llm_model.prepare_inputs(max_batch_size,
|
||||
max_input_len,
|
||||
max_output_len, True,
|
||||
max_beam_width)
|
||||
tensorrt_llm_model(*inputs)
|
||||
|
||||
if args.mode == 'plugin':
|
||||
tensorrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
serialize_path = os.path.join(args.output_dir, engine_name)
|
||||
serialize_engine(engine, serialize_path)
|
||||
if runtime_rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder_config.plugin_config = network.plugin_config
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
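A minimal sketch of driving build_gpt directly, e.g. from a quick test; every value below is illustrative, and the model name must be one of the entries known to get_build_config():

    from argparse import Namespace
    args = Namespace(model='gpt_350m', mode='plugin', dtype='float16',
                     quantization=None, output_dir=None,
                     max_batch_size=None, max_input_len=None,
                     max_output_len=None, max_beam_width=None,
                     force_num_layer_1=True)
    engine_buffer = build_gpt(args)  # with output_dir=None the engine is returned but not written to disk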
|
||||
|
||||
|
||||
def build_bert(args):
|
||||
build_config = get_build_config(args.model)
|
||||
if args.force_num_layer_1:
|
||||
build_config['num_layers'] = 1
|
||||
|
||||
# More parameters
|
||||
world_size = tensorrt_llm.mpi_world_size()
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
torch.cuda.set_device(runtime_rank)
|
||||
num_kv_heads = build_config['num_heads'] \
|
||||
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
|
||||
max_batch_size = build_config['max_batch_size'] \
|
||||
if args.max_batch_size is None else args.max_batch_size
|
||||
max_input_len = build_config['max_input_len'] \
|
||||
if args.max_input_len is None else args.max_input_len
|
||||
bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
|
||||
inlen_range = [1, (max_input_len + 1) // 2, max_input_len]
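# The two triplets above follow TensorRT's (min, opt, max) optimization-profile
# convention; e.g. max_batch_size=8, max_input_len=128 gives [1, 4, 8] and [1, 64, 128].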
|
||||
|
||||
builder = Builder()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=args.model,
|
||||
precision=args.dtype,
|
||||
timing_cache=None,
|
||||
tensor_parallel=world_size, # TP only
|
||||
parallel_build=True,
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=max_input_len,
|
||||
opt_level=build_config['builder_opt'])
|
||||
engine_name = get_engine_name(args.model, args.dtype, world_size,
|
||||
runtime_rank)
|
||||
|
||||
# Initialize model
|
||||
tensorrt_llm_bert = tensorrt_llm.models.BertModel(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
type_vocab_size=build_config['type_vocab_size'],
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size, tp_size=world_size))
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
|
||||
# Plugins
|
||||
if args.mode == 'plugin':
|
||||
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.enable_qk_half_accum()
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
|
||||
if world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(
|
||||
dtype=args.dtype,
|
||||
use_custom_all_reduce=build_config["use_custom_all_reduce"])
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(tensorrt_llm_bert.named_parameters())
|
||||
|
||||
# Forward
|
||||
input_ids = tensorrt_llm.Tensor(
|
||||
name='input_ids',
|
||||
dtype=trt.int32,
|
||||
shape=[-1, -1],
|
||||
dim_range=OrderedDict([('batch_size', [bs_range]),
|
||||
('input_len', [inlen_range])]),
|
||||
)
|
||||
input_lengths = tensorrt_llm.Tensor(name='input_lengths',
|
||||
dtype=trt.int32,
|
||||
shape=[-1],
|
||||
dim_range=OrderedDict([
|
||||
('batch_size', [bs_range])
|
||||
]))
|
||||
hidden_states = tensorrt_llm_bert(input_ids=input_ids,
|
||||
input_lengths=input_lengths)
|
||||
|
||||
# Mark outputs
|
||||
hidden_states_dtype = str_dtype_to_trt(args.dtype)
|
||||
hidden_states.mark_output('hidden_states', hidden_states_dtype)
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
|
||||
|
||||
if args.output_dir is not None:
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
serialize_path = os.path.join(args.output_dir, engine_name)
|
||||
serialize_engine(engine, serialize_path)
|
||||
if runtime_rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder_config.plugin_config = network.plugin_config
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
|
||||
|
||||
|
||||
def main(args):
|
||||
logger.set_level(args.log_level)
|
||||
if args.model in get_allowed_models(benchmark_type="gpt"):
|
||||
build_gpt(args)
|
||||
elif args.model in get_allowed_models(benchmark_type="bert"):
|
||||
build_bert(args)
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mp.set_start_method('spawn')
|
||||
args = parse_arguments()
|
||||
main(args)
|
||||
@ -17,158 +17,95 @@ import time
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
from allowed_configs import get_build_config, get_model_family
|
||||
from base_benchmark import BaseBenchmark, get_engine_name, serialize_engine
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_trt
|
||||
from tensorrt_llm.builder import Builder
|
||||
from tensorrt_llm.layers import PositionEmbeddingType
|
||||
from tensorrt_llm.models import quantize_model
|
||||
from tensorrt_llm.network import net_guard
|
||||
from tensorrt_llm.plugin.plugin import ContextFMHAType
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
|
||||
from allowed_configs import get_build_config # isort:skip
|
||||
from base_benchmark import BaseBenchmark # isort:skip
|
||||
from build import build_gpt, get_quant_mode # isort:skip
|
||||
|
||||
|
||||
class GPTBenchmark(BaseBenchmark):
|
||||
|
||||
def __init__(self,
|
||||
engine_dir,
|
||||
model_name,
|
||||
mode,
|
||||
batch_sizes,
|
||||
in_out_lens,
|
||||
dtype,
|
||||
refit,
|
||||
num_beams,
|
||||
top_k,
|
||||
top_p,
|
||||
output_dir,
|
||||
n_positions=None,
|
||||
max_input_len=None,
|
||||
max_output_len=None,
|
||||
max_batch_size=None,
|
||||
enable_custom_all_reduce=None,
|
||||
**kwargs):
|
||||
super().__init__(engine_dir, model_name, dtype, output_dir)
|
||||
def __init__(self, args, batch_sizes, in_out_lens):
|
||||
super().__init__(args.engine_dir, args.model, args.dtype)
|
||||
self.batch_sizes = batch_sizes
|
||||
self.in_out_lens = in_out_lens
|
||||
self.refit = refit
|
||||
self.num_beams = num_beams
|
||||
self.num_beams = args.num_beams
|
||||
self.mode = args.mode
|
||||
self.build_time = 0
|
||||
self.mode = mode # plugin or ootb or ootb-except-mha
|
||||
self.fuse_bias = True
|
||||
|
||||
self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
|
||||
self.strongly_typed = kwargs.get('strongly_typed', False)
|
||||
self.enable_custom_all_reduce = enable_custom_all_reduce
|
||||
self.cuda_graph_mode = args.enable_cuda_graph
|
||||
|
||||
if engine_dir is not None:
|
||||
if args.engine_dir is not None:
|
||||
# Get build configs from engine directory is done in base class
|
||||
# Deserialize engine from engine directory
|
||||
self.serialize_path = os.path.join(engine_dir, self.engine_name)
|
||||
self.serialize_path = os.path.join(args.engine_dir,
|
||||
self.engine_name)
|
||||
with open(self.serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
else:
|
||||
# Build engine
|
||||
self.world_size = tensorrt_llm.mpi_world_size()
|
||||
self.apply_query_key_layer_scaling = False
|
||||
|
||||
self.use_weight_only = False
|
||||
self.per_group = False
|
||||
self.weight_only_precision = 'int8'
|
||||
self.per_token = False
|
||||
self.per_channel = False
|
||||
|
||||
use_mha_plugin = mode == 'plugin' or mode == 'ootb-except-mha'
|
||||
mha_plg_dtype = dtype if use_mha_plugin else False
|
||||
use_non_mha_plugin = mode == 'plugin'
|
||||
non_mha_plg_dtype = dtype if use_non_mha_plugin else False
|
||||
|
||||
self.use_gpt_attention_plugin = mha_plg_dtype
|
||||
self.use_gemm_plugin = non_mha_plg_dtype
|
||||
# Starting TRT9.1 OOTB norm layer sees improvement over plugin norm layer
|
||||
self.use_layernorm_plugin = False
|
||||
self.use_rmsnorm_plugin = False
|
||||
self.use_lookup_plugin = non_mha_plg_dtype
|
||||
self.use_weight_only_quant_gemm_plugin = non_mha_plg_dtype
|
||||
self.enable_context_fmha = use_mha_plugin
|
||||
|
||||
self.remove_input_padding = use_non_mha_plugin
|
||||
|
||||
for key, value in get_build_config(model_name).items():
|
||||
for key, value in get_build_config(args.model).items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if self.quantization is None:
|
||||
self.quantization = kwargs.get('quantization', None)
|
||||
|
||||
self.set_quantization()
|
||||
|
||||
# Override the n_position/max_input_len/max_output_len/max_batch_size to value from cmd line if that's specified.
|
||||
if n_positions is not None:
|
||||
assert isinstance(
|
||||
n_positions, int
|
||||
) and n_positions > 0, f"n_positions should be a valid int number, got {n_positions}"
|
||||
self.n_positions = n_positions
|
||||
if max_input_len is not None:
|
||||
assert isinstance(
|
||||
max_input_len, int
|
||||
) and max_input_len > 0, f"max_input_len should be a valid int number, got {max_input_len}"
|
||||
self.max_input_len = max_input_len
|
||||
if max_output_len is not None:
|
||||
assert isinstance(
|
||||
max_output_len, int
|
||||
) and max_output_len > 0, f"max_output_len should be a valid int number, got {max_output_len}"
|
||||
self.max_output_len = max_output_len
|
||||
if max_batch_size is not None:
|
||||
assert isinstance(
|
||||
max_batch_size, int
|
||||
) and max_batch_size > 0, f"max_batch_size should be a valid int number, got {max_batch_size}"
|
||||
self.max_batch_size = max_batch_size
|
||||
if self.num_kv_heads is None:
|
||||
self.num_kv_heads = self.num_heads
|
||||
if kwargs.get('force_num_layer_1', False):
|
||||
if args.force_num_layer_1:
|
||||
self.num_layers = 1
|
||||
engine_buffer = self.build()
|
||||
self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
|
||||
self.enable_fp8 = self.quant_mode.has_fp8_qdq()
|
||||
self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
|
||||
|
||||
# Plugins
|
||||
self.use_gpt_attention_plugin = False
|
||||
self.remove_input_padding = False
|
||||
if args.mode == 'plugin':
|
||||
self.use_gpt_attention_plugin = True
|
||||
self.remove_input_padding = True
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
self.use_gpt_attention_plugin = True
|
||||
|
||||
start = time.time()
|
||||
engine_buffer = build_gpt(args)
|
||||
self.build_time = round(time.time() - start, 2)
|
||||
|
||||
assert engine_buffer is not None
|
||||
|
||||
if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
|
||||
self.num_kv_heads = self.num_heads
|
||||
model_config = tensorrt_llm.runtime.ModelConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads // self.world_size,
|
||||
num_kv_heads=ceil(self.num_kv_heads / self.world_size),
|
||||
hidden_size=self.hidden_size // self.world_size,
|
||||
vocab_size=self.vocab_size,
|
||||
num_layers=self.num_layers,
|
||||
gpt_attention_plugin=self.use_gpt_attention_plugin,
|
||||
remove_input_padding=self.remove_input_padding,
|
||||
quant_mode=self.quant_mode,
|
||||
use_custom_all_reduce=self.enable_custom_all_reduce,
|
||||
use_custom_all_reduce=self.use_custom_all_reduce,
|
||||
)
|
||||
if model_name == 'chatglm_6b':
|
||||
if args.model == 'chatglm_6b':
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=130005,
|
||||
pad_id=3,
|
||||
num_beams=num_beams,
|
||||
top_k=top_k,
|
||||
top_p=top_p)
|
||||
num_beams=self.num_beams,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p)
|
||||
self.decoder = tensorrt_llm.runtime.ChatGLMGenerationSession(
|
||||
model_config, engine_buffer, self.runtime_mapping)
|
||||
elif model_name in ['chatglm2_6b', 'chatglm3_6b']:
|
||||
elif args.model in ['chatglm2_6b', 'chatglm3_6b']:
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=2,
|
||||
pad_id=0,
|
||||
num_beams=num_beams,
|
||||
top_k=top_k,
|
||||
top_p=top_p)
|
||||
num_beams=self.num_beams,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p)
|
||||
self.decoder = tensorrt_llm.runtime.GenerationSession(
|
||||
model_config, engine_buffer, self.runtime_mapping)
|
||||
else:
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=50256,
|
||||
pad_id=50256,
|
||||
num_beams=num_beams,
|
||||
top_k=top_k,
|
||||
top_p=top_p)
|
||||
num_beams=self.num_beams,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p)
|
||||
self.decoder = tensorrt_llm.runtime.GenerationSession(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
@ -201,370 +138,37 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.decoder.setup(batch_size, inlen, outlen, beam_width=self.num_beams)
|
||||
return (input_ids, input_lengths)
|
||||
|
||||
def set_quantization(self):
|
||||
self.quant_mode = QuantMode(0)
|
||||
def get_report_dict(self, benchmark_profiler=None):
|
||||
report_dict = super().get_report_dict(
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
if benchmark_profiler is not None:
|
||||
report_dict["generation_time(ms)"] = None
|
||||
report_dict["total_generated_tokens"] = None
|
||||
report_dict["generation_tokens_per_second"] = None
|
||||
return report_dict
|
||||
|
||||
if self.quantization == "fp8":
|
||||
self.strongly_typed = True
|
||||
self.quant_mode = self.quant_mode.set_fp8_qdq()
|
||||
self.quant_mode = self.quant_mode.set_fp8_kv_cache()
|
||||
|
||||
elif self.quantization == "fp8_gemm":
|
||||
self.strongly_typed = True
|
||||
self.quant_mode = self.quant_mode.set_fp8_qdq()
|
||||
|
||||
elif self.quantization == "fp8_kv_cache":
|
||||
self.strongly_typed = True
|
||||
self.quant_mode = self.quant_mode.set_fp8_kv_cache()
|
||||
|
||||
elif self.quantization == "int8_sq_per_tensor":
|
||||
self.use_smooth_quant = True
|
||||
self.quant_mode = QuantMode.use_smooth_quant(
|
||||
self.per_token, self.per_channel)
|
||||
|
||||
elif self.quantization == "int8_sq_per_token_channel":
|
||||
self.use_smooth_quant = True
|
||||
self.per_token = True
|
||||
self.per_channel = True
|
||||
self.quant_mode = QuantMode.use_smooth_quant(
|
||||
self.per_token, self.per_channel)
|
||||
|
||||
elif self.quantization == "int8_weight_only":
|
||||
self.use_smooth_quant = False
|
||||
self.use_weight_only = True
|
||||
self.weight_only_precision = 'int8'
|
||||
self.quant_mode = QuantMode.use_weight_only(False)
|
||||
|
||||
elif self.quantization == "int4_weight_only":
|
||||
self.use_weight_only = True
|
||||
self.weight_only_precision = 'int4'
|
||||
self.quant_mode = QuantMode.use_weight_only(True)
|
||||
|
||||
elif self.quantization == "int4_weight_only_awq":
|
||||
self.use_weight_only = True
|
||||
self.per_group = True
|
||||
self.weight_only_precision = 'int4_awq'
|
||||
self.quant_mode = QuantMode.from_description(
|
||||
quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
|
||||
elif self.quantization == "int4_weight_only_gptq":
|
||||
self.use_weight_only = True
|
||||
self.per_group = True
|
||||
self.weight_only_precision = 'int4_gptq'
|
||||
self.quant_mode = QuantMode.from_description(
|
||||
quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
|
||||
elif self.quantization == None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f'{0} is invalid config: {self.quantization}')
|
||||
|
||||
def build(self):
|
||||
builder = Builder()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=self.model_name,
|
||||
precision=self.dtype,
|
||||
timing_cache=None,
|
||||
tensor_parallel=self.world_size, # TP only
|
||||
parallel_build=True,
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
apply_query_key_layer_scaling=self.apply_query_key_layer_scaling,
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_input_len=self.max_input_len,
|
||||
max_output_len=self.max_output_len,
|
||||
int8=self.quant_mode.has_act_and_weight_quant()
|
||||
or self.quant_mode.is_int8_weight_only(),
|
||||
quant_mode=self.quant_mode,
|
||||
use_refit=self.refit,
|
||||
opt_level=self.builder_opt,
|
||||
strongly_typed=self.strongly_typed)
|
||||
engine_name = get_engine_name(self.model_name, self.dtype,
|
||||
self.world_size, self.runtime_rank)
|
||||
|
||||
kv_dtype = str_dtype_to_trt(self.dtype)
|
||||
|
||||
# Initialize Module
|
||||
family = get_model_family(self.model_name)
|
||||
if family == "gpt":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
position_embedding_type=PositionEmbeddingType.learned_absolute
|
||||
if self.position_embedding_type is None else
|
||||
self.position_embedding_type,
|
||||
rotary_embedding_percentage=self.rotary_pct,
|
||||
quant_mode=self.quant_mode,
|
||||
bias=self.bias)
|
||||
elif family == "opt":
|
||||
tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
pre_norm=self.pre_norm,
|
||||
do_layer_norm_before=self.do_layer_norm_before)
|
||||
elif family == "llama":
|
||||
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=self.inter_size,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
quant_mode=self.quant_mode)
|
||||
elif family == "gptj":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
rotary_dim=self.rotary_dim,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
quant_mode=self.quant_mode)
|
||||
elif family == "gptneox":
|
||||
tensorrt_llm_model = tensorrt_llm.models.GPTNeoXForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
rotary_dim=self.rotary_dim,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling)
|
||||
elif family == "chatglm":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_name="chatglm_6b")
|
||||
elif family == "chatglm2":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_name="chatglm2_6b")
|
||||
elif family == "chatglm3":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_name="chatglm3_6b")
|
||||
elif family == "bloom":
|
||||
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=tensorrt_llm.Mapping(
|
||||
world_size=self.world_size,
|
||||
tp_size=self.world_size), # TP only
|
||||
quant_mode=self.quant_mode,
|
||||
use_parallel_embedding=(self.model_name == 'bloom_176b'))
|
||||
elif family == "falcon":
|
||||
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
bias=self.bias,
|
||||
quant_mode=self.quant_mode,
|
||||
use_alibi=self.use_alibi,
|
||||
new_decoder_architecture=self.new_decoder_architecture,
|
||||
parallel_attention=self.parallel_attention,
|
||||
mapping=tensorrt_llm.Mapping(world_size=self.world_size,
|
||||
tp_size=self.world_size))
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {self.model_name}')
|
||||
|
||||
quant_kwargs = {}
|
||||
if family == "llama" and self.use_weight_only:
|
||||
if self.weight_only_precision == 'int4_awq':
|
||||
quant_kwargs = {
|
||||
"group_size": 128,
|
||||
"zero": False,
|
||||
"pre_quant_scale": True,
|
||||
"exclude_modules": [],
|
||||
}
|
||||
elif self.weight_only_precision == 'int4_gptq':
|
||||
quant_kwargs = {
|
||||
"group_size": 128,
|
||||
"zero": True,
|
||||
"pre_quant_scale": False,
|
||||
}
|
||||
tensorrt_llm_model = quantize_model(tensorrt_llm_model, self.quant_mode,
|
||||
**quant_kwargs)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
|
||||
not_fp8_quantization = self.quantization is None or "fp8" not in self.quantization
|
||||
|
||||
if self.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=self.use_gpt_attention_plugin)
|
||||
if self.use_gemm_plugin and not_fp8_quantization:
|
||||
network.plugin_config.set_gemm_plugin(dtype=self.use_gemm_plugin)
|
||||
if self.use_layernorm_plugin:
|
||||
network.plugin_config.set_layernorm_plugin(
|
||||
dtype=self.use_layernorm_plugin)
|
||||
if self.use_rmsnorm_plugin:
|
||||
network.plugin_config.set_rmsnorm_plugin(
|
||||
dtype=self.use_rmsnorm_plugin)
|
||||
if self.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if self.remove_input_padding:
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
|
||||
# Quantization plugins.
|
||||
if self.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=self.dtype)
|
||||
network.plugin_config.set_layernorm_quantization_plugin(
|
||||
dtype=self.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
elif self.use_weight_only and self.use_weight_only_quant_gemm_plugin:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=self.dtype)
|
||||
|
||||
# RMS norm plugin for SmoothQuant
|
||||
if self.quant_mode.has_act_and_weight_quant(
|
||||
) and 'llama' in self.model_name:
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin()
|
||||
|
||||
if self.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(self.dtype,
|
||||
self.enable_custom_all_reduce)
|
||||
|
||||
# Use the plugin for the embedding parallelism and sharing
|
||||
network.plugin_config.set_lookup_plugin(dtype=self.use_lookup_plugin)
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(tensorrt_llm_model.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = tensorrt_llm_model.prepare_inputs(self.max_batch_size,
|
||||
self.max_input_len,
|
||||
self.max_output_len,
|
||||
True, self.num_beams)
|
||||
tensorrt_llm_model(*inputs)
|
||||
|
||||
if self.fuse_bias:
|
||||
tensorrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
# Network -> Engine
|
||||
start = time.time()
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
end = time.time()
|
||||
self.build_time = round(end - start, 2)
|
||||
|
||||
if self.output_dir is not None:
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
self.serialize_path = os.path.join(self.output_dir,
|
||||
self.engine_name)
|
||||
serialize_engine(engine, self.serialize_path)
|
||||
if self.runtime_rank == 0:
|
||||
config_path = os.path.join(self.output_dir, 'config.json')
|
||||
builder_config.plugin_config = network.plugin_config
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
|
||||
|
||||
def run(self, inputs, config):
|
||||
def run(self, inputs, config, benchmark_profiler=None):
|
||||
batch_size, inlen, outlen = config[0], config[1], config[2]
|
||||
self.decoder.setup(batch_size, inlen, outlen, beam_width=self.num_beams)
|
||||
if self.remove_input_padding:
|
||||
self.decoder.decode_batch(inputs[0], self.sampling_config)
|
||||
self.decoder.decode_batch(inputs[0],
|
||||
self.sampling_config,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
else:
|
||||
self.decoder.decode(inputs[0], inputs[1], self.sampling_config)
|
||||
self.decoder.decode(inputs[0],
|
||||
inputs[1],
|
||||
self.sampling_config,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def report(self, config, latency, percentile95, percentile99, peak_gpu_used,
|
||||
csv):
|
||||
def report(self,
|
||||
config,
|
||||
latency,
|
||||
percentile95,
|
||||
percentile99,
|
||||
peak_gpu_used,
|
||||
csv,
|
||||
benchmark_profiler=None):
|
||||
report_dict = super().get_report_dict()
|
||||
batch_size, inlen, outlen = config[0], config[1], config[2]
|
||||
tokens_per_sec = round(batch_size * outlen / (latency / 1000), 2)
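# e.g. batch_size=8, outlen=100, latency=500 ms  ->  8 * 100 / 0.5 = 1600 tokens/s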
|
||||
@ -582,6 +186,20 @@ class GPTBenchmark(BaseBenchmark):
|
||||
report_dict["percentile95(ms)"] = percentile95
|
||||
report_dict["percentile99(ms)"] = percentile99
|
||||
report_dict["gpu_peak_mem(gb)"] = peak_gpu_used
|
||||
if benchmark_profiler is not None:
|
||||
iter_count = benchmark_profiler.get_aux_info('iter_count')
|
||||
generation_time_ms = benchmark_profiler.get_timer_value(
|
||||
'generation_time')
|
||||
generation_step_count = benchmark_profiler.get_aux_info(
|
||||
'generation_step_count')
|
||||
token_per_step = batch_size * self.num_beams
|
||||
total_tokens = generation_step_count * token_per_step
|
||||
report_dict["generation_time(ms)"] = generation_time_ms / iter_count
|
||||
report_dict["total_generated_tokens"] = total_tokens / iter_count
|
||||
tokens_per_second = round(
|
||||
total_tokens * 1000.0 / generation_time_ms, 3)
|
||||
report_dict["generation_tokens_per_second"] = tokens_per_second
|
||||
|
||||
if self.runtime_rank == 0:
|
||||
if csv:
|
||||
line = ",".join([str(v) for v in report_dict.values()])
|
||||
|
||||
@ -112,31 +112,6 @@ private:
|
||||
std::atomic<bool> shutdown_requested_;
|
||||
void decoupled_execution_loop();
|
||||
std::shared_ptr<std::thread> worker_thread_;
|
||||
inline static const std::string kInputIdsTensorName_ = "input_ids";
|
||||
inline static const std::string kDraftInputIdsTensorName_ = "draft_input_ids";
|
||||
inline static const std::string kMaxNewTokensTensorName_ = "request_output_len";
|
||||
inline static const std::string kBeamWidthTensorName_ = "beam_width";
|
||||
inline static const std::string kEndIdTensorName_ = "end_id";
|
||||
inline static const std::string kPadIdTensorName_ = "pad_id";
|
||||
inline static const std::string kBadWordsListTensorName_ = "bad_words_list";
|
||||
inline static const std::string kStopWordsListTensorName_ = "stop_words_list";
|
||||
inline static const std::string kEmbeddingBiasTensorName_ = "embedding_bias";
|
||||
inline static const std::string kTemperatureTensorName_ = "temperature";
|
||||
inline static const std::string kRuntimeTopKTensorName_ = "runtime_top_k";
|
||||
inline static const std::string kRuntimeTopPTensorName_ = "runtime_top_p";
|
||||
inline static const std::string kLengthPenaltyTensorName_ = "len_penalty";
|
||||
inline static const std::string kRepetitionPenaltyTensorName_ = "repetition_penalty";
|
||||
inline static const std::string kMinLengthTensorName_ = "min_length";
|
||||
inline static const std::string kPresencePenaltyTensorName_ = "presence_penalty";
|
||||
inline static const std::string kRandomSeedTensorName_ = "random_seed";
|
||||
inline static const std::string kReturnLogProbsTensorName_ = "return_log_probs";
|
||||
inline static const std::string kPromptEmbeddingTableName_ = "prompt_embedding_table";
|
||||
inline static const std::string kPromptVocabSizeName_ = "prompt_vocab_size";
|
||||
inline static const std::string kOutputIdsTensorName_ = "output_ids";
|
||||
inline static const std::string kSequenceLengthTensorName_ = "sequence_length";
|
||||
inline static const std::string kLogProbsTensorName_ = "output_log_probs";
|
||||
inline static const std::string kCumLogProbsTensorName_ = "cum_log_probs";
|
||||
|
||||
std::shared_ptr<nvinfer1::ILogger> mLogger{};
|
||||
};
|
||||
|
||||
|
||||
@ -16,113 +16,82 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/batch_manager/NamedTensor.h"
|
||||
#include "tensorrt_llm/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
template <typename TTensor, typename TTensorMap>
|
||||
namespace inference_request
|
||||
{
|
||||
// Input tensors
|
||||
auto constexpr kInputIdsTensorName = "input_ids";
|
||||
auto constexpr kDraftInputIdsTensorName = "draft_input_ids";
|
||||
auto constexpr kDraftLogitsTensorName = "draft_logits";
|
||||
auto constexpr kMaxNewTokensTensorName = "request_output_len";
|
||||
auto constexpr kBeamWidthTensorName = "beam_width";
|
||||
auto constexpr kEndIdTensorName = "end_id";
|
||||
auto constexpr kPadIdTensorName = "pad_id";
|
||||
auto constexpr kBadWordsListTensorName = "bad_words_list";
|
||||
auto constexpr kStopWordsListTensorName = "stop_words_list";
|
||||
auto constexpr kEmbeddingBiasTensorName = "embedding_bias";
|
||||
auto constexpr kTemperatureTensorName = "temperature";
|
||||
auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
|
||||
auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
|
||||
auto constexpr kLengthPenaltyTensorName = "len_penalty";
|
||||
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
|
||||
auto constexpr kMinLengthTensorName = "min_length";
|
||||
auto constexpr kPresencePenaltyTensorName = "presence_penalty";
|
||||
auto constexpr kRandomSeedTensorName = "random_seed";
|
||||
auto constexpr kReturnLogProbsTensorName = "return_log_probs";
|
||||
auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
|
||||
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
|
||||
|
||||
// Output tensors
|
||||
auto constexpr kOutputIdsTensorName = "output_ids";
|
||||
auto constexpr kSequenceLengthTensorName = "sequence_length";
|
||||
auto constexpr kLogProbsTensorName = "output_log_probs";
|
||||
auto constexpr kCumLogProbsTensorName = "cum_log_probs";
|
||||
} // namespace inference_request
|
||||
|
||||
template <typename TTensor, typename TNamedTensor>
|
||||
class GenericInferenceRequest
|
||||
{
|
||||
public:
|
||||
using TensorPtr = TTensor;
|
||||
using TensorMap = TTensorMap;
|
||||
using NamedTensorType = TNamedTensor;
|
||||
using TensorMap = std::unordered_map<std::string, TTensor>;
|
||||
|
||||
GenericInferenceRequest(uint64_t requestId)
|
||||
: mRequestId(requestId)
|
||||
, mIsStreaming(false)
|
||||
explicit GenericInferenceRequest(uint64_t requestId)
|
||||
: mRequestId{requestId}
|
||||
, mIsStreaming{false}
|
||||
{
|
||||
}
|
||||
|
||||
GenericInferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
|
||||
: mInputTensors(inputTensors)
|
||||
, mRequestId(requestId)
|
||||
, mIsStreaming(false)
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap)
|
||||
: mRequestId{requestId}
|
||||
, mIsStreaming{false}
|
||||
, mInputTensors{std::move(tensorMap)}
|
||||
{
|
||||
}
|
||||
|
||||
GenericInferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
|
||||
: mInputTensors(std::move(inputTensors))
|
||||
, mRequestId(requestId)
|
||||
, mIsStreaming(false)
|
||||
{
|
||||
}
|
||||
|
||||
~GenericInferenceRequest() {}
|
||||
|
||||
template <typename T>
|
||||
std::tuple<bool, T> getScalarValueFromTensor(
|
||||
const std::string& inputTensorName, const std::vector<int64_t>& expectedShape, const bool is_optional) const
|
||||
{
|
||||
T scalarValue;
|
||||
try
|
||||
for (auto const& [name, tensor] : mInputTensors)
|
||||
{
|
||||
const auto& t = getInputTensor(inputTensorName);
|
||||
std::vector<int64_t> tensorShape(t->getShape().nbDims);
|
||||
for (int32_t i = 0; i < t->getShape().nbDims; ++i)
|
||||
{
|
||||
tensorShape[i] = t->getShape().d[i];
|
||||
}
|
||||
|
||||
if (tensorShape != expectedShape)
|
||||
{
|
||||
std::string err = "Invalid shape for " + inputTensorName + ". Expected shape: [";
|
||||
for (auto shape : expectedShape)
|
||||
{
|
||||
err += std::to_string(shape) + ",";
|
||||
}
|
||||
if (!expectedShape.empty())
|
||||
{
|
||||
// Remove last comma
|
||||
err.pop_back();
|
||||
}
|
||||
err += "]";
|
||||
|
||||
throw std::runtime_error(err);
|
||||
}
|
||||
scalarValue = *static_cast<T*>(t->data());
|
||||
validateTensorName(name);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
// If parameter is optional, just ignore it
|
||||
if (is_optional)
|
||||
{
|
||||
return {false, scalarValue};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Out of Range error for tensor: " << inputTensorName << ": " << e.what() << '\n';
|
||||
throw;
|
||||
}
|
||||
}
|
||||
return {true, scalarValue};
|
||||
}
|
||||
|
||||
const TensorPtr& getInputTensor(std::string const& inputTensorName) const
|
||||
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap)
|
||||
: GenericInferenceRequest(requestId, TensorMap{tensorMap})
|
||||
{
|
||||
return mInputTensors.at(inputTensorName);
|
||||
}
|
||||
|
||||
void emplaceInputTensor(std::string const& inputTensorName, TensorPtr&& inputTensor)
|
||||
{
|
||||
mInputTensors.emplace(inputTensorName, std::move(inputTensor));
|
||||
}
|
||||
|
||||
void setIsStreaming(bool isStreaming)
|
||||
@ -130,91 +99,152 @@ public:
|
||||
mIsStreaming = isStreaming;
|
||||
}
|
||||
|
||||
bool isStreaming() const
|
||||
[[nodiscard]] bool isStreaming() const
|
||||
{
|
||||
return mIsStreaming;
|
||||
}
|
||||
|
||||
uint64_t getRequestId() const
|
||||
[[nodiscard]] uint64_t getRequestId() const
|
||||
{
|
||||
return mRequestId;
|
||||
}
|
||||
|
||||
static std::array constexpr kTensorNames = {
|
||||
inference_request::kInputIdsTensorName,
|
||||
inference_request::kDraftInputIdsTensorName,
|
||||
inference_request::kDraftLogitsTensorName,
|
||||
inference_request::kMaxNewTokensTensorName,
|
||||
inference_request::kBeamWidthTensorName,
|
||||
inference_request::kEndIdTensorName,
|
||||
inference_request::kPadIdTensorName,
|
||||
inference_request::kBadWordsListTensorName,
|
||||
inference_request::kStopWordsListTensorName,
|
||||
inference_request::kEmbeddingBiasTensorName,
|
||||
inference_request::kTemperatureTensorName,
|
||||
inference_request::kRuntimeTopKTensorName,
|
||||
inference_request::kRuntimeTopPTensorName,
|
||||
inference_request::kLengthPenaltyTensorName,
|
||||
inference_request::kRepetitionPenaltyTensorName,
|
||||
inference_request::kMinLengthTensorName,
|
||||
inference_request::kPresencePenaltyTensorName,
|
||||
inference_request::kRandomSeedTensorName,
|
||||
inference_request::kReturnLogProbsTensorName,
|
||||
inference_request::kPromptEmbeddingTableName,
|
||||
inference_request::kPromptVocabSizeName,
|
||||
};
|
||||
|
||||
#define TENSOR_GETTER_SETTER(funcName, tensorName) \
\
[[nodiscard]] bool has##funcName() const \
{ \
return mInputTensors.find(tensorName) != mInputTensors.end(); \
} \
\
[[nodiscard]] TensorPtr const& get##funcName() const \
{ \
auto it = mInputTensors.find(tensorName); \
TLLM_CHECK_WITH_INFO(it != mInputTensors.end(), "Undefined tensor: %s", tensorName); \
return it->second; \
} \
\
[[nodiscard]] TensorPtr get##funcName##Unchecked() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? it->second : TensorPtr{}; \
} \
\
[[nodiscard]] NamedTensorType get##funcName##Named() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? NamedTensorType{it->second, tensorName} : NamedTensor{tensorName}; \
} \
\
void set##funcName(TensorPtr const& tensor) \
{ \
mInputTensors[tensorName] = tensor; \
}
|
||||
|
||||
TENSOR_GETTER_SETTER(InputIds, inference_request::kInputIdsTensorName)
|
||||
TENSOR_GETTER_SETTER(DraftInputIds, inference_request::kDraftInputIdsTensorName)
|
||||
TENSOR_GETTER_SETTER(DraftLogits, inference_request::kDraftLogitsTensorName)
|
||||
TENSOR_GETTER_SETTER(MaxNewTokens, inference_request::kMaxNewTokensTensorName)
|
||||
TENSOR_GETTER_SETTER(BeamWidth, inference_request::kBeamWidthTensorName)
|
||||
TENSOR_GETTER_SETTER(EndId, inference_request::kEndIdTensorName)
|
||||
TENSOR_GETTER_SETTER(PadId, inference_request::kPadIdTensorName)
|
||||
TENSOR_GETTER_SETTER(BadWordsList, inference_request::kBadWordsListTensorName)
|
||||
TENSOR_GETTER_SETTER(StopWordsList, inference_request::kStopWordsListTensorName)
|
||||
TENSOR_GETTER_SETTER(EmbeddingBias, inference_request::kEmbeddingBiasTensorName)
|
||||
TENSOR_GETTER_SETTER(Temperature, inference_request::kTemperatureTensorName)
|
||||
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
|
||||
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
|
||||
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
|
||||
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
|
||||
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
|
||||
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
|
||||
TENSOR_GETTER_SETTER(RandomSeed, inference_request::kRandomSeedTensorName)
|
||||
TENSOR_GETTER_SETTER(ReturnLogProbs, inference_request::kReturnLogProbsTensorName)
|
||||
TENSOR_GETTER_SETTER(PromptEmbeddingTable, inference_request::kPromptEmbeddingTableName)
|
||||
TENSOR_GETTER_SETTER(PromptVocabSize, inference_request::kPromptVocabSizeName)
|
||||
|
||||
#undef TENSOR_GETTER_SETTER
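// Illustrative usage sketch (not part of this commit): exercises the accessors generated by
// TENSOR_GETTER_SETTER above. The include path, request id and the tensor `inputIds` are assumed;
// in practice the tensor would come from the runtime BufferManager.
#include "tensorrt_llm/batch_manager/inferenceRequest.h"

void sketchUseGeneratedAccessors(tensorrt_llm::runtime::ITensor::SharedPtr inputIds)
{
    using tensorrt_llm::batch_manager::InferenceRequest;
    InferenceRequest request{/*requestId=*/42};
    request.setInputIds(inputIds); // stored under inference_request::kInputIdsTensorName
    if (request.hasMaxNewTokens())
    {
        auto const& maxNewTokens = request.getMaxNewTokens(); // checked access, throws if missing
        (void) maxNewTokens;
    }
    auto topK = request.getRuntimeTopKUnchecked(); // unchecked access, empty TensorPtr when unset
    (void) topK;
}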
|
||||
|
||||
protected:
|
||||
TensorMap mInputTensors;
|
||||
static void validateTensorName(std::string const& tensorName)
|
||||
{
|
||||
// TODO (martinma): Throw an exception if the tensor name is not valid.
|
||||
if (std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) == kTensorNames.end())
|
||||
{
|
||||
TLLM_LOG_WARNING("Invalid tensor name in InferenceRequest: %s", tensorName.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t mRequestId;
|
||||
bool mIsStreaming;
|
||||
TensorMap mInputTensors;
|
||||
};
|
||||
|
||||
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
|
||||
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>
|
||||
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>
|
||||
{
|
||||
public:
|
||||
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
|
||||
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>;
|
||||
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
using TensorMap = Base::TensorMap;
|
||||
|
||||
InferenceRequest(uint64_t requestId)
|
||||
explicit InferenceRequest(uint64_t requestId)
|
||||
: Base(requestId)
|
||||
{
|
||||
}
|
||||
|
||||
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
|
||||
: Base(inputTensors, requestId)
|
||||
: Base(requestId, inputTensors)
|
||||
{
|
||||
}
|
||||
|
||||
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
|
||||
: Base(inputTensors, requestId)
|
||||
: Base(requestId, std::move(inputTensors))
|
||||
{
|
||||
}
|
||||
|
||||
const std::vector<int64_t> serialize() const
|
||||
[[deprecated("Use direct tensor access instead")]] [[nodiscard]] TensorPtr const& getInputTensor(
|
||||
std::string const& inputTensorName) const
|
||||
{
|
||||
std::list<int64_t> packed;
|
||||
// mInputTensors
|
||||
packed.push_back(static_cast<int64_t>(mInputTensors.size()));
|
||||
for (auto it = mInputTensors.begin(); it != mInputTensors.end(); ++it)
|
||||
{
|
||||
NamedTensor nt(it->second, it->first);
|
||||
auto packed_tensor = nt.serialize();
|
||||
packed.push_back(static_cast<int64_t>(packed_tensor.size()));
|
||||
packed.insert(packed.end(), packed_tensor.begin(), packed_tensor.end());
|
||||
}
|
||||
// mRequestId
|
||||
packed.push_back(static_cast<int64_t>(mRequestId));
|
||||
// mIsStreaming
|
||||
packed.push_back(mIsStreaming ? 1 : 0);
|
||||
// done
|
||||
std::vector<int64_t> vpacked{
|
||||
std::make_move_iterator(std::begin(packed)), std::make_move_iterator(std::end(packed))};
|
||||
return vpacked;
|
||||
auto it = Base::mInputTensors.find(inputTensorName);
|
||||
TLLM_CHECK_WITH_INFO(it != Base::mInputTensors.end(), "Invalid input tensor name: %s", inputTensorName.c_str());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed)
|
||||
[[deprecated("Use direct tensor access instead")]] void emplaceInputTensor(
|
||||
std::string const& inputTensorName, TensorPtr inputTensor)
|
||||
{
|
||||
return InferenceRequest::deserialize(packed.data());
|
||||
validateTensorName(inputTensorName);
|
||||
Base::mInputTensors[inputTensorName] = std::move(inputTensor);
|
||||
}
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr)
|
||||
{
|
||||
int64_t num_tensors = *packed_ptr++;
|
||||
TensorMap InputTensors;
|
||||
for (int64_t i = 0; i < num_tensors; ++i)
|
||||
{
|
||||
int64_t n = *packed_ptr++;
|
||||
auto inputTensor = NamedTensor::deserialize(packed_ptr);
|
||||
packed_ptr += n;
|
||||
auto inputTensorName = inputTensor.name;
|
||||
InputTensors.emplace(inputTensorName, std::move(inputTensor.tensor));
|
||||
}
|
||||
uint64_t RequestId = static_cast<uint64_t>(*packed_ptr++);
|
||||
bool IsStreaming = *packed_ptr++ != 0;
|
||||
std::shared_ptr<InferenceRequest> ir = std::make_shared<InferenceRequest>(InputTensors, RequestId);
|
||||
ir->setIsStreaming(IsStreaming);
|
||||
return ir;
|
||||
}
|
||||
[[nodiscard]] std::vector<int64_t> serialize() const;
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed);
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr);
|
||||
};
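// Illustrative sketch (not part of this commit) of the packed layout produced by
// InferenceRequest::serialize() as shown above:
//   [ numTensors,
//     { packedTensorLength_i, packedTensor_i... } for each input tensor,
//     requestId, isStreaming ]
// The helper below only walks that framing; decoding each payload is left to NamedTensor::deserialize.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<uint64_t, bool> sketchPeekRequestIdAndStreaming(std::vector<int64_t> const& packed)
{
    std::size_t pos = 0;
    auto const numTensors = packed.at(pos++);
    for (int64_t i = 0; i < numTensors; ++i)
    {
        auto const packedLen = packed.at(pos++); // length of this NamedTensor's payload
        pos += static_cast<std::size_t>(packedLen);
    }
    auto const requestId = static_cast<uint64_t>(packed.at(pos++));
    bool const isStreaming = packed.at(pos++) != 0;
    return {requestId, isStreaming};
}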
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -55,7 +55,8 @@ public:
|
||||
std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -75,6 +76,7 @@ public:
|
||||
, mLogProbs(samplingConfig.beamWidth)
|
||||
, mCumLogProbs(samplingConfig.beamWidth)
|
||||
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
|
||||
, mDraftLogits(draftLogits)
|
||||
{
|
||||
mMaxSentTokenPos = mPromptLen - 1;
|
||||
// Scatter the input tokens to other beam
|
||||
@ -89,6 +91,13 @@ public:
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
throw std::runtime_error(errStr);
|
||||
}
|
||||
|
||||
if (draftLogits.has_value() && !draftTokens.has_value())
|
||||
{
|
||||
std::string errStr = "Draft tokens must be specified when draft logits are given.";
|
||||
TLLM_LOG_ERROR(errStr);
|
||||
throw std::runtime_error(errStr);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Get total number of tokens for this req (prompt + generated)
|
||||
@ -135,6 +144,13 @@ public:
|
||||
return mDraftTokens;
|
||||
}
|
||||
|
||||
/// @brief Get the logits for the draft tokens
|
||||
/// @return Tensor of draft logits
|
||||
std::optional<TensorPtr> getDraftLogits() const
|
||||
{
|
||||
return mDraftLogits;
|
||||
}
|
||||
|
||||
/// @brief Returns true if request has draft tokens
|
||||
/// @return flag
|
||||
bool hasDraftTokens() const
|
||||
@ -162,7 +178,7 @@ public:
|
||||
/// beamTokens is expected to be of size beamWidth
|
||||
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
|
||||
{
|
||||
assert(mSamplingConfig.beamWidth == beamTokens.size());
|
||||
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
|
||||
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
|
||||
{
|
||||
const auto outputId = beamTokens[beam];
|
||||
@ -174,7 +190,7 @@ public:
|
||||
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
|
||||
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
|
||||
{
|
||||
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
|
||||
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
|
||||
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
|
||||
{
|
||||
auto& beamTokens = mTokens[beam];
|
||||
@ -305,6 +321,11 @@ public:
|
||||
mDraftTokens = draftTokens;
|
||||
}
|
||||
|
||||
void setDraftLogits(const std::optional<TensorPtr>& draftLogits)
|
||||
{
|
||||
mDraftLogits = draftLogits;
|
||||
}
|
||||
|
||||
RequestIdType mRequestId;
|
||||
SizeType mPromptLen;
|
||||
SizeType mMaxNewTokens;
|
||||
@ -333,6 +354,7 @@ protected:
|
||||
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
|
||||
VecLogProbs mCumLogProbs; // [beamSize]
|
||||
std::shared_ptr<VecTokens> mDraftTokens;
|
||||
std::optional<TensorPtr> mDraftLogits;
|
||||
};
|
||||
|
||||
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
|
||||
@ -353,9 +375,11 @@ public:
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
|
||||
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
|
||||
std::optional<TensorPtr> draftLogits = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
|
||||
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens)
|
||||
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens,
|
||||
draftLogits)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -18,11 +18,14 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
template <typename TTensor>
|
||||
struct GenericNamedTensor
|
||||
class GenericNamedTensor
|
||||
{
|
||||
public:
|
||||
using TensorPtr = TTensor;
|
||||
|
||||
TensorPtr tensor;
|
||||
@ -31,24 +34,32 @@ struct GenericNamedTensor
|
||||
GenericNamedTensor() = default;
|
||||
~GenericNamedTensor() = default;
|
||||
|
||||
// Host Tensor constructor
|
||||
GenericNamedTensor(
|
||||
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
|
||||
|
||||
GenericNamedTensor(TensorPtr _tensor, std::string _name)
|
||||
: tensor(std::move(_tensor))
|
||||
, name(std::move(_name))
|
||||
: tensor{std::move(_tensor)}
|
||||
, name{std::move(_name)}
|
||||
{
|
||||
}
|
||||
|
||||
GenericNamedTensor(std::string _name)
|
||||
: name(std::move(_name))
|
||||
explicit GenericNamedTensor(std::string _name)
|
||||
: tensor{}
|
||||
, name{std::move(_name)}
|
||||
{
|
||||
}
|
||||
|
||||
TensorPtr operator()()
|
||||
{
|
||||
return tensor;
|
||||
}
|
||||
|
||||
TensorPtr const& operator()() const
|
||||
{
|
||||
return tensor;
|
||||
}
|
||||
};
|
||||
|
||||
struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
|
||||
class NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
|
||||
{
|
||||
public:
|
||||
using Base = GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>;
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
|
||||
@ -56,9 +67,13 @@ struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::S
|
||||
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
|
||||
|
||||
NamedTensor(TensorPtr _tensor, std::string _name)
|
||||
: Base(_tensor, _name){};
|
||||
: Base(std::move(_tensor), std::move(_name)){};
|
||||
|
||||
explicit NamedTensor(std::string _name)
|
||||
: Base(std::move(_name)){};
|
||||
|
||||
[[nodiscard]] std::vector<int64_t> serialize() const;
|
||||
|
||||
std::vector<int64_t> serialize();
|
||||
static NamedTensor deserialize(const int64_t* packed);
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
@ -33,16 +33,19 @@ public:
using SizeType = tensorrt_llm::runtime::SizeType;

explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true)
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
bool useContextFMHAForGeneration = false)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, useContextFMHAForGeneration(useContextFMHAForGeneration)
{
}

KvCacheConfig kvCacheConfig;
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
bool useContextFMHAForGeneration;
};

} // namespace tensorrt_llm::batch_manager
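// Illustrative sketch (not part of this commit): constructing the optional parameters with the
// new useContextFMHAForGeneration flag. Namespace qualifiers, headers and KvCacheConfig defaults
// are assumed/elided here.
TrtGptModelOptionalParams const optionalParams{
    KvCacheConfig{},   // default KV cache settings
    std::nullopt,      // maxNumSequences: no explicit limit
    /*enableTrtOverlap=*/true,
    /*useContextFMHAForGeneration=*/true};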
|
||||
|
||||
@ -71,9 +71,7 @@ public:
|
||||
// Vector of views on newTokensSteps for each token. Elements are on gpu.
|
||||
|
||||
// optional parameters
|
||||
TensorPtr finishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states at each generated token of
|
||||
// maxTokensPerStep, on gpu
|
||||
TensorPtr finished; // [batchSize, beamWidth], usually a view of finishedSteps for current token.
|
||||
TensorPtr finished; // [batchSize, beamWidth],
|
||||
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
|
||||
// true. In beam search and to determine whether to stop according to
|
||||
// DecodingInput.sequenceLimitLength, on gpu
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
@ -55,9 +56,17 @@ public:
|
||||
DecodingInput const& decodingInput, BufferManager const& manager)
|
||||
= 0;
|
||||
|
||||
static void acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds, const ITensor& contextLengths,
|
||||
const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec, ITensor& finishedFinal,
|
||||
ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream);
|
||||
virtual const SamplingConfig& getSamplingConfig() = 0;
|
||||
|
||||
static void acceptDraftTokensByIds(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
|
||||
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths,
|
||||
const ITensor& finishedVec, ITensor& finishedFinal, ITensor& finishedSum,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static void acceptDraftTokensByLogits(ITensor& draftLogits, const ITensor& targetLogits, ITensor& draftProbs,
|
||||
ITensor& targetProbs, const ITensor& numDraftTokens, ITensor& finished, SizeType vocabSize,
|
||||
SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
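// Illustrative host-side sketch (not part of this commit) of the greedy accept-by-ids rule that
// acceptDraftTokensByIds performs on the GPU, under the usual speculative-decoding assumption:
// draft tokens are accepted while they match the target model's tokens and acceptance stops at the
// first mismatch. Shapes, context lengths and finished-state handling are simplified away here.
#include <cstddef>
#include <vector>

std::size_t sketchCountAcceptedDraftTokens(
    std::vector<int> const& draftTokens, std::vector<int> const& targetTokens)
{
    std::size_t accepted = 0;
    while (accepted < draftTokens.size() && accepted < targetTokens.size()
        && draftTokens[accepted] == targetTokens[accepted])
    {
        ++accepted;
    }
    return accepted;
}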
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(
|
||||
nvinfer1::DataType dtype, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
|
||||
@ -82,6 +91,11 @@ public:
|
||||
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
|
||||
BufferManager const& manager) override;
|
||||
|
||||
const SamplingConfig& getSamplingConfig() override
|
||||
{
|
||||
return mSamplingConfig;
|
||||
}
|
||||
|
||||
private:
|
||||
BufferManager mManager;
|
||||
|
||||
@ -90,6 +104,7 @@ private:
|
||||
|
||||
TensorPtr mLogProbsTiled; // Buffer used to store the transpose of the logProbs. Needed because the kernels have
|
||||
// been written to use that shape.
|
||||
SamplingConfig mSamplingConfig;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(
|
||||
|
||||
@ -181,7 +181,10 @@ private:
|
||||
DecodingOutputPtr mJointDecodingOutput;
|
||||
|
||||
std::vector<TensorPtr> mDraftTokenIds;
|
||||
std::vector<TensorPtr> mDraftLogits;
|
||||
std::vector<bool> mAcceptByLogits;
|
||||
TensorPtr mNumDraftTokens;
|
||||
TensorPtr mCurandStates;
|
||||
|
||||
std::vector<SizeType> mNbSteps;
|
||||
std::vector<bool> mFinished;
|
||||
@ -189,6 +192,13 @@ private:
|
||||
std::vector<SizeType> mMaxNewTokens;
|
||||
std::vector<SizeType> mBeamWidths;
|
||||
std::vector<SizeType> mGeneratedTokensPerStep;
|
||||
|
||||
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
|
||||
// for each generated token of maxTokensPerStep, on gpu
|
||||
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxKvCacheLength{};
|
||||
SizeType mActualBatchSize{};
|
||||
|
||||
@ -33,6 +33,8 @@
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
@ -213,19 +215,7 @@ public:
|
||||
|
||||
[[nodiscard]] constexpr std::size_t getSize() const noexcept
|
||||
{
|
||||
switch (static_cast<nvinfer1::DataType>(*this))
|
||||
{
|
||||
case nvinfer1::DataType::kINT64: return 8;
|
||||
case nvinfer1::DataType::kINT32: [[fallthrough]];
|
||||
case nvinfer1::DataType::kFLOAT: return 4;
|
||||
case nvinfer1::DataType::kBF16: [[fallthrough]];
|
||||
case nvinfer1::DataType::kHALF: return 2;
|
||||
case nvinfer1::DataType::kBOOL: [[fallthrough]];
|
||||
case nvinfer1::DataType::kUINT8: [[fallthrough]];
|
||||
case nvinfer1::DataType::kINT8: [[fallthrough]];
|
||||
case nvinfer1::DataType::kFP8: return 1;
|
||||
}
|
||||
return 0;
|
||||
return tensorrt_llm::common::getDTypeSize(static_cast<nvinfer1::DataType>(*this));
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@ -65,6 +65,8 @@ public:
|
||||
std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
|
||||
std::optional<SizeType> endId; // end token id
|
||||
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
|
||||
std::optional<TensorPtr>
|
||||
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft tokens from speculative decoding
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength], on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
@ -238,6 +239,39 @@ public:
|
||||
//!
|
||||
static std::string toString(Shape const& dims);
|
||||
|
||||
//!
|
||||
//! \brief A convenience function to compare shapes.
|
||||
//!
|
||||
static bool shapeEquals(Shape const& lhs, Shape const& rhs)
|
||||
{
|
||||
return shapeEquals(lhs, rhs.d, rhs.nbDims);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief A convenience function to compare shapes.
|
||||
//!
|
||||
template <typename T>
|
||||
static bool shapeEquals(Shape const& lhs, T const* dims, SizeType count)
|
||||
{
|
||||
return lhs.nbDims == count && std::equal(lhs.d, lhs.d + lhs.nbDims, dims);
|
||||
}
|
||||
|
||||
bool shapeEquals(Shape const& other) const
|
||||
{
|
||||
return shapeEquals(getShape(), other);
|
||||
}
|
||||
|
||||
bool shapeEquals(std::initializer_list<SizeType> const& other) const
|
||||
{
|
||||
return shapeEquals(getShape(), other.begin(), other.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool shapeEquals(T const* dims, SizeType count) const
|
||||
{
|
||||
return shapeEquals(getShape(), dims, count);
|
||||
}
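// Illustrative sketch (not part of this commit): the shapeEquals overloads added above make shape
// checks concise. `tensor`, `batchSize` and `beamWidth` are assumed to come from surrounding code,
// and the assertion header providing TLLM_CHECK_WITH_INFO is assumed to be included.
void sketchCheckDecoderShape(tensorrt_llm::runtime::ITensor const& tensor,
    tensorrt_llm::runtime::SizeType batchSize, tensorrt_llm::runtime::SizeType beamWidth)
{
    using tensorrt_llm::runtime::ITensor;
    TLLM_CHECK_WITH_INFO(tensor.shapeEquals({batchSize, beamWidth}), "Unexpected shape: %s",
        ITensor::toString(tensor.getShape()).c_str());
}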
|
||||
|
||||
protected:
|
||||
ITensor() = default;
|
||||
|
||||
|
||||
@ -56,6 +56,9 @@ public:
|
||||
// beam search layer
|
||||
OptVec<FloatType> beamSearchDiversityRate;
|
||||
OptVec<FloatType> lengthPenalty;
|
||||
|
||||
// speculative decoding
|
||||
OptVec<FloatType> draftAcceptanceThreshold;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9e6a5d7dba399049a4da9ca729153e5a6080986782a314b867e7635454eb36de
|
||||
size 1705954
|
||||
oid sha256:ba982afff27c597c9f5f25bec4ed37debd883c7be2107b47776a014075899fbd
|
||||
size 1719266
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:64fae7bca97be7c3067b4544da0c3d79621ec3632c10e39b7a005d886702e8eb
|
||||
size 1706098
|
||||
oid sha256:04ec1f2f45dde1ef6b6b0f605e79715eebed38b19b4d833fcb668d2cb71f8a03
|
||||
size 1733118
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
02375d908e57e2194e3f28a4e83dd963 libtensorrt_llm_batch_manager_static.a
|
||||
e5c4994ecc347d808f6d38fb686b5cf1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
cd83045d7c127af5f907efd7c710bd6fe1f90ec4 commit
|
||||
aab384dfc59de5df4c7ecf53e30d03e9 libtensorrt_llm_batch_manager_static.a
|
||||
e0074afa6959c896f1cbc7ab90872058 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
c7450aa071e91659a3e2855c0cca21021f96ada8 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3cca913fc62df4119e4df10921be97086714740148f54c528da7bb2826f67ba
|
||||
size 1617426
|
||||
oid sha256:546c9e2b79cb3cf2623876902ef2d40c65925157d43850b2505eedf274e060a1
|
||||
size 1638840
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3d633874e8b32a56758bf8bbdc0955ed8c5d43d531ba330a2274bcec13e1c89f
|
||||
size 1620144
|
||||
oid sha256:935a706ce0d107f8c226566a50946a0f0e35ce926c98b7a12b000b3d72e5f0b6
|
||||
size 1635602
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
f379e62b3f69afa4bd1d8e5551a6ede4 libtensorrt_llm_batch_manager_static.a
|
||||
92f44b2834d39c2c62a9b0bd0549b159 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
8e0c5b31d579f4118b84a34ffb00c15a libtensorrt_llm_batch_manager_static.a
|
||||
885ea7b9f594d7aa9cc9018527b95f6d libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -45,11 +45,13 @@ extern bool CHECK_DEBUG_ENABLED;
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} while (0)

#define TLLM_CHECK_WITH_INFO(val, info) \
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} while (0)
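// Illustrative usage (not part of this commit): with the variadic form above, callers can pass
// printf-style arguments directly instead of pre-formatting the message, e.g.
//
//     TLLM_CHECK_WITH_INFO(it != mInputTensors.end(), "Undefined tensor: %s", tensorName);
//     TLLM_CHECK_WITH_INFO(beamWidth > 0, "Invalid beam width: %d", beamWidth);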
|
||||
|
||||
#define TLLM_CHECK_DEBUG(val) \
|
||||
|
||||
@ -57,6 +57,7 @@ CUDADriverWrapper::CUDADriverWrapper()
|
||||
*(void**) (&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*(void**) (&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*(void**) (&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*(void**) (&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*(void**) (&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
@ -109,6 +110,11 @@ CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod
|
||||
return (*_cuModuleGetFunction)(hfunc, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const
|
||||
{
|
||||
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
|
||||
@ -54,6 +54,8 @@ public:
|
||||
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const;
|
||||
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const;
|
||||
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions,
|
||||
CUjit_option* options, void** optionValues) const;
|
||||
|
||||
@ -78,6 +80,7 @@ private:
|
||||
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, const void*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLinkAddData)(
|
||||
CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
|
||||
@ -287,6 +287,16 @@ inline int getMultiProcessorCount()
return multi_processor_count;
}

inline int getMaxSharedMemoryPerBlockOptin()
{
int device_id;
int max_shared_memory_per_block;
check_cuda_error(cudaGetDevice(&device_id));
check_cuda_error(
cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
return max_shared_memory_per_block;
}
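// Illustrative sketch (not part of this commit): the typical use of the opt-in limit queried above
// is to raise a kernel's dynamic shared memory cap before launching it with a large smem size.
// `myKernel` and `smemBytes` are hypothetical.
template <typename KernelFn>
void sketchEnableLargeSmem(KernelFn* myKernel, int smemBytes)
{
    int const maxSmem = getMaxSharedMemoryPerBlockOptin();
    if (smemBytes > maxSmem)
    {
        return; // requested dynamic shared memory exceeds the device opt-in limit; fall back instead
    }
    check_cuda_error(
        cudaFuncSetAttribute(myKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes));
}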
|
||||
|
||||
inline int divUp(int a, int n)
|
||||
{
|
||||
return (a + n - 1) / n;
|
||||
|
||||
41
cpp/tensorrt_llm/common/dataType.h
Normal file
41
cpp/tensorrt_llm/common/dataType.h
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
namespace tensorrt_llm::common
{

constexpr static size_t getDTypeSize(nvinfer1::DataType type)
{
switch (type)
{
case nvinfer1::DataType::kINT64: return 8;
case nvinfer1::DataType::kINT32: [[fallthrough]];
case nvinfer1::DataType::kFLOAT: return 4;
case nvinfer1::DataType::kBF16: [[fallthrough]];
case nvinfer1::DataType::kHALF: return 2;
case nvinfer1::DataType::kBOOL: [[fallthrough]];
case nvinfer1::DataType::kUINT8: [[fallthrough]];
case nvinfer1::DataType::kINT8: [[fallthrough]];
case nvinfer1::DataType::kFP8: return 1;
}
return 0;
}

} // namespace tensorrt_llm::common
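// Illustrative sketch (not part of this commit): because getDTypeSize is constexpr it can be used
// both at compile time and for runtime buffer-size math. The helper name below is hypothetical.
static_assert(tensorrt_llm::common::getDTypeSize(nvinfer1::DataType::kHALF) == 2, "unexpected half size");

inline std::size_t sketchBufferBytes(nvinfer1::DataType type, std::size_t elementCount)
{
    return elementCount * tensorrt_llm::common::getDTypeSize(type); // 0 for unknown types, mirroring the switch above
}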
|
||||
@ -38,7 +38,7 @@ std::string vformat(char const* fmt, va_list args)
|
||||
std::string stringBuf(size, char{});
|
||||
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(size2 == size, std::strerror(errno));
|
||||
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
|
||||
|
||||
return stringBuf;
|
||||
}
|
||||
|
||||
@ -42,6 +42,16 @@ static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& str
|
||||
return stream;
|
||||
}
|
||||
|
||||
inline std::string fmtstr(std::string const& s)
|
||||
{
|
||||
return s;
|
||||
}
|
||||
|
||||
inline std::string fmtstr(std::string&& s)
|
||||
{
|
||||
return s;
|
||||
}
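// Illustrative note (not part of this commit): these pass-through overloads let the variadic
// TLLM_CHECK_WITH_INFO macro expand to fmtstr(info) when no format arguments are supplied, e.g.
//
//     TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));   // no varargs
//     TLLM_CHECK_WITH_INFO(ok, "rank %d out of range", rank);                   // printf-style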
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
std::string fmtstr(char const* format, ...);
|
||||
#else
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/scale_type.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
@ -48,11 +48,23 @@ struct EpilogueOpBiasFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpNoBias
|
||||
struct EpilogueOpDefault
|
||||
{
|
||||
};
|
||||
|
||||
@ -61,40 +73,66 @@ struct Epilogue
|
||||
{
|
||||
};
|
||||
|
||||
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
|
||||
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator,
|
||||
cutlass::epilogue::thread::ScaleType::NoBetaScaling, cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
|
||||
ElementAccumulator, BiasScaleMode>;
|
||||
};
|
||||
|
||||
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
|
||||
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
|
||||
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
|
||||
cutlass::FloatRoundStyle::round_to_nearest, true>;
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
|
||||
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
|
||||
{
|
||||
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
|
||||
ElementAccumulator, cutlass::epilogue::thread::ScaleType::Default>;
|
||||
ElementAccumulator, DefaultScaleMode>;
|
||||
};
|
||||
|
||||
} // namespace cutlass_extensions
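// Illustrative sketch (not part of this commit): downstream GEMM launchers select an epilogue
// functor through this trait by tag; the element type and access width below are hypothetical,
// and the aliases assume they sit where the cutlass_extensions namespace contents are visible.
using SiluBiasEpilogueOp = Epilogue<cutlass::half_t, 8, float, EpilogueOpBiasSilu>::Op;
using PlainEpilogueOp = Epilogue<cutlass::half_t, 8, float, EpilogueOpDefault>::Op;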
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Scheduler for grouped GEMM
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
|
||||
bool Transposed = false>
|
||||
struct GemmMoeProblemVisitor
|
||||
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
|
||||
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
|
||||
{
|
||||
|
||||
static bool const kTransposed = Transposed;
|
||||
|
||||
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
|
||||
using Base
|
||||
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
|
||||
using Params = typename Base::Params;
|
||||
using SharedStorage = typename Base::SharedStorage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, shared_storage_, block_idx)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,522 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
|
||||
#include "cutlass_extensions/tile_interleaved_layout.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// This section exists so that we can use the same kernel code for regular gemm and dequantizing gemms.
// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
template <typename...>
using void_t = void;

template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};

template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};
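// Illustrative sketch (not part of this commit): the void_t detection above evaluates to true only
// when Mma exposes a nested IteratorScale type. The two toy types are hypothetical and assume this
// sits where use_dq_gemm is visible.
struct SketchRegularMma
{
};

struct SketchDequantMma
{
    using IteratorScale = int; // any nested type named IteratorScale triggers the specialization
};

static_assert(!use_dq_gemm<SketchRegularMma>::value, "regular GEMM path expected");
static_assert(use_dq_gemm<SketchDequantMma>::value, "dequantizing GEMM path expected");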
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
||||
typename Epilogue_, ///! Epilogue
|
||||
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
||||
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
||||
/// arch.
|
||||
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
|
||||
>
|
||||
struct MoeFCGemm
|
||||
{
|
||||
public:
|
||||
using Mma = Mma_;
|
||||
using Epilogue = Epilogue_;
|
||||
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
||||
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||||
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
|
||||
static bool const kTransposed = false;
|
||||
|
||||
// Optional transpose
|
||||
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
|
||||
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
|
||||
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
|
||||
kTransposed>;
|
||||
|
||||
// Public-facing type definitions related to operand element type, layout, and complex conjugate
|
||||
// operation. Must interact with the 'kTransposed' notion.
|
||||
static_assert(!kTransposed, "Transpose problem not supported");
|
||||
using ElementA = typename MapArguments::ElementA;
|
||||
using LayoutA = typename MapArguments::LayoutA;
|
||||
using ElementB = typename MapArguments::ElementB;
|
||||
using LayoutB = typename MapArguments::LayoutB;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename MapArguments::LayoutC;
|
||||
using ElementScale = ElementC;
|
||||
|
||||
static ComplexTransform const kTransformA = MapArguments::kTransformA;
|
||||
static ComplexTransform const kTransformB = MapArguments::kTransformB;
|
||||
|
||||
// Type definitions about the mainloop.
|
||||
using Operator = typename Mma::Operator;
|
||||
using OperatorClass = typename Mma::Operator::OperatorClass;
|
||||
using ThreadblockShape = typename Mma::Shape;
|
||||
using WarpShape = typename Mma::Operator::Shape;
|
||||
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
||||
using ArchTag = typename Mma::ArchTag;
|
||||
|
||||
static int const kStages = Mma::kStages;
|
||||
static int const kAlignmentA = MapArguments::kAlignmentA;
|
||||
static int const kAlignmentB = MapArguments::kAlignmentB;
|
||||
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
/// Warp count (concept: GemmShape)
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static int const kThreadCount = 32 * WarpCount::kCount;
|
||||
|
||||
using ProblemVisitor
|
||||
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
|
||||
|
||||
//
|
||||
// Structures
|
||||
//
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
int problem_count;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
int64_t* total_rows_before_expert;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
|
||||
// Only used by device-level operator
|
||||
GemmCoord* host_problem_sizes;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments()
|
||||
: problem_count(0)
|
||||
, threadblock_count(0)
|
||||
, ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
, total_rows_before_expert(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, host_problem_sizes(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
|
||||
const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C,
|
||||
ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k,
|
||||
GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count)
|
||||
, threadblock_count(threadblock_count)
|
||||
, group_size(group_size)
|
||||
, output_op(output_op)
|
||||
, ptr_A(const_cast<ElementA*>(ptr_A))
|
||||
, ptr_B(const_cast<ElementB*>(ptr_B))
|
||||
, weight_scales(const_cast<ElementScale*>(weight_scales))
|
||||
, ptr_C(const_cast<ElementC*>(ptr_C))
|
||||
, ptr_D(ptr_D)
|
||||
, total_rows_before_expert(total_rows_before_expert)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, host_problem_sizes(nullptr)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
assert(weight_scales);
|
||||
}
|
||||
}
|
||||
};
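// Illustrative host-side sketch (not part of this commit): total_rows_before_expert is a running
// (inclusive) prefix sum of rows per expert, so expert i owns rows
// [total_rows_before_expert[i-1], total_rows_before_expert[i]) and its GEMM problem is
// rows_i x gemm_n x gemm_k. This mirrors how the kernel derives rows_to_jump further below.
inline cutlass::gemm::GemmCoord sketchProblemSizeForExpert(
    int64_t const* total_rows_before_expert, int expert_idx, int64_t gemm_n, int64_t gemm_k)
{
    int64_t const row_begin = expert_idx == 0 ? 0 : total_rows_before_expert[expert_idx - 1];
    int64_t const row_end = total_rows_before_expert[expert_idx];
    return cutlass::gemm::GemmCoord(
        static_cast<int>(row_end - row_begin), static_cast<int>(gemm_n), static_cast<int>(gemm_k));
}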
|
||||
|
||||
//
|
||||
// Structure for precomputing values in host memory and passing to kernels
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
|
||||
typename ProblemVisitor::Params problem_visitor;
|
||||
int threadblock_count;
|
||||
int group_size;
|
||||
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
|
||||
ElementA* ptr_A;
|
||||
ElementB* ptr_B;
|
||||
ElementScale* weight_scales;
|
||||
ElementC* ptr_C;
|
||||
ElementC* ptr_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: ptr_A(nullptr)
|
||||
, ptr_B(nullptr)
|
||||
, weight_scales(nullptr)
|
||||
, ptr_C(nullptr)
|
||||
, ptr_D(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
: problem_visitor(
|
||||
args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
|
||||
, threadblock_count(args.threadblock_count)
|
||||
, group_size(args.group_size)
|
||||
, output_op(args.output_op)
|
||||
, ptr_A(args.ptr_A)
|
||||
, ptr_B(args.ptr_B)
|
||||
, weight_scales(args.weight_scales)
|
||||
, ptr_C(args.ptr_C)
|
||||
, ptr_D(args.ptr_D)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
|
||||
{
|
||||
|
||||
problem_visitor = typename ProblemVisitor::Params(
|
||||
args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count);
|
||||
threadblock_count = args.threadblock_count;
|
||||
output_op = args.output_op;
|
||||
ptr_A = args.ptr_A;
|
||||
ptr_B = args.ptr_B;
|
||||
weight_scales = args.weight_scales;
|
||||
ptr_C = args.ptr_C;
|
||||
ptr_D = args.ptr_D;
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename ProblemVisitor::SharedStorage problem_visitor;
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MoeFCGemm() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
|
||||
{
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
|
||||
{
|
||||
if (args.weight_scales == nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
}
|
||||
else if (args.weight_scales != nullptr)
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
else if (args.group_size != args.gemm_k)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
// Handle the case where the input is too short
|
||||
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
||||
{
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The dummy template parameter is not used and exists so that we can compile this code using
// a standard earlier than C++17. Prior to C++17, fully specialized templates had to exist in
// a namespace.
|
||||
template <bool B, typename dummy = void>
|
||||
struct KernelRunner
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dummy>
|
||||
struct KernelRunner<true, dummy>
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
int loop = 0;
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
loop++;
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
const int64_t rows_to_jump
|
||||
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B
|
||||
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
|
||||
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
auto CreateMMA = [&]()
|
||||
{
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
else
|
||||
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
};
|
||||
Mma mma = CreateMMA();

// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();

// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();

if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);

mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}

//
// Epilogue
//

EpilogueOutputOp output_op(params.output_op);

ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;

LayoutC layout_C(0);
LayoutC layout_D(gemm_n);

typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);

// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());

// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());

Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);

// Next tile
problem_visitor.advance(gridDim.x);
}
}
};

/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@ -0,0 +1,344 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*! \file
|
||||
\brief Base scheduler for grouped problems, using MoE
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Visitor class to abstract away the algorithm for iterating over tiles
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape_>
|
||||
struct BaseMoeProblemVisitor
|
||||
{
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
|
||||
struct ProblemInfo
|
||||
{
|
||||
static int32_t const kNoPrefetchEntry = -1;
|
||||
int32_t problem_idx;
|
||||
int32_t problem_start;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo()
|
||||
: problem_idx(kNoPrefetchEntry)
|
||||
, problem_start(kNoPrefetchEntry)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
|
||||
: problem_idx(problem_idx_)
|
||||
, problem_start(problem_start_)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
int64_t const* last_row_for_problem;
|
||||
int64_t gemm_n;
|
||||
int64_t gemm_k;
|
||||
int32_t problem_count;
|
||||
void const* workspace;
|
||||
int32_t tile_count;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: last_row_for_problem(nullptr)
|
||||
, gemm_n(0)
|
||||
, gemm_k(0)
|
||||
, problem_count(0)
|
||||
, workspace(nullptr)
|
||||
, tile_count(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
|
||||
void const* workspace = nullptr, int32_t tile_count = 0)
|
||||
: last_row_for_problem(last_row_for_problem)
|
||||
, gemm_n(gemm_n)
|
||||
, gemm_k(gemm_k)
|
||||
, problem_count(problem_count)
|
||||
, workspace(workspace)
|
||||
, tile_count(tile_count)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
Params const& params;
|
||||
int32_t tile_idx;
|
||||
int32_t problem_tile_start;
|
||||
int32_t problem_idx;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
|
||||
: params(params_)
|
||||
, tile_idx(block_idx)
|
||||
, problem_tile_start(0)
|
||||
, problem_idx(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Get the grid shape
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem)
{

return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
}
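// Ceiling division: e.g. with a 128x128 threadblock shape and a 1000x4096 problem,
// the grid is ((1000 - 1 + 128) / 128) x ((4096 - 1 + 128) / 128) x 1 = 8 x 32 x 1 tiles.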
|
||||
|
||||
/// Gets the global tile index
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t tile_index() const
|
||||
{
|
||||
return tile_idx;
|
||||
}
|
||||
|
||||
/// Gets the index of the problem
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t problem_index() const
|
||||
{
|
||||
return problem_idx;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int32_t threadblock_idx() const
|
||||
{
|
||||
return tile_idx - problem_tile_start;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance(int32_t grid_size)
|
||||
{
|
||||
tile_idx += grid_size;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
|
||||
{
|
||||
ProblemSizeHelper::possibly_transpose_problem(problem);
|
||||
}
|
||||
|
||||
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const
{
return problem_size(problem_idx);
}

CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const
{
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
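// last_row_for_problem is a running (inclusive prefix-sum) row count, so gemm_m for problem
// `idx` is the number of rows routed to that problem, while gemm_n and gemm_k are shared by
// all problems in the group.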
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int32_t tile_count(const cutlass::gemm::GemmCoord& grid)
|
||||
{
|
||||
return ProblemSizeHelper::tile_count(grid);
|
||||
}
|
||||
|
||||
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count)
|
||||
{
|
||||
int32_t total_tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
{
|
||||
auto problem = host_problem_sizes_ptr[i];
|
||||
possibly_transpose_problem(problem);
|
||||
auto grid = grid_shape(problem);
|
||||
total_tiles += tile_count(grid);
|
||||
}
|
||||
|
||||
return total_tiles;
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
|
||||
int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// ProblemVisitor that performs all scheduling on device
|
||||
//
|
||||
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
|
||||
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
|
||||
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
|
||||
{
|
||||
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
|
||||
using Params = typename Base::Params;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
static bool const kRequiresPrecomputation = false;
|
||||
static int const kThreadsPerWarp = 32;
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
|
||||
// Final tile of the problem loaded by this thread. Each thread will hold
|
||||
// a separate value.
|
||||
int32_t problem_ending_tile;
|
||||
|
||||
SharedStorage& shared_storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_DEVICE
|
||||
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
|
||||
: Base(params_, block_idx)
|
||||
, problem_ending_tile(0)
|
||||
, shared_storage(shared_storage_)
|
||||
{
|
||||
this->problem_idx = -1 * kThreadsPerWarp;
|
||||
this->problem_tile_start = 0;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool next_tile()
|
||||
{
|
||||
// Check whether the tile to compute is within the range of the current problem.
|
||||
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
|
||||
if (this->tile_idx < problem_tile_end)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether the tile to compute is within the current group of problems fetched by the warp.
|
||||
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
|
||||
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
|
||||
|
||||
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
|
||||
// register pressure. The starting problem for this group is simply the first problem
|
||||
// in the group most recently fetched by the warp.
|
||||
int32_t& group_problem_start = this->problem_idx;
|
||||
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
|
||||
|
||||
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
|
||||
// register pressure.
|
||||
int32_t& group_tile_start = this->problem_tile_start;
|
||||
|
||||
// Each thread in the warp processes a separate problem to advance until
// reaching a problem whose starting tile is less than tile_idx.
|
||||
while (group_tile_end <= this->tile_idx)
|
||||
{
|
||||
group_problem_start += kThreadsPerWarp;
|
||||
if (group_problem_start > this->params.problem_count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
|
||||
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
|
||||
// is also set here is used later in `next_tile`.
|
||||
group_tile_start = group_tile_end;
|
||||
|
||||
int lane_idx = threadIdx.x % kThreadsPerWarp;
|
||||
int32_t lane_problem = group_problem_start + lane_idx;
|
||||
|
||||
// Compute the number of tiles in the problem assigned to each thread.
|
||||
problem_ending_tile = 0;
|
||||
if (lane_problem < this->params.problem_count)
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
|
||||
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
|
||||
problem_ending_tile = this->tile_count(grid);
|
||||
}
|
||||
|
||||
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
// each thread's problem.
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
{
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i)
{
problem_ending_tile += val;
}
}
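// This is a Hillis-Steele inclusive scan over the warp: e.g. per-lane tile counts
// {3, 1, 4, 2, ...} become running totals {3, 4, 8, 10, ...}.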

// The total tile count for this group is now in the final position of the prefix sum
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}

// The next problem to process is the first one whose ending tile position is greater
// than the tile index.
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
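// Because the ending tiles are non-decreasing across lanes, counting the lanes whose ending
// tile is <= tile_idx (via ballot + popc) yields the in-group index of the problem that owns
// tile_idx.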
|
||||
|
||||
this->problem_idx = group_problem_start + problem_idx_in_group;
|
||||
|
||||
// The starting tile for this problem is the ending tile of the previous problem. In cases
|
||||
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
|
||||
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
|
||||
// loop above.
|
||||
if (problem_idx_in_group > 0)
|
||||
{
|
||||
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(
|
||||
const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
|
||||
int32_t block_count, void* host_workspace_ptr)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -25,7 +25,7 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bool* finished_buf,
|
||||
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int* parent_ids_buf, int batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
|
||||
int vocab_size_padded, size_t step)
|
||||
{
|
||||
@ -60,7 +60,7 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bo
|
||||
}
|
||||
|
||||
// if the beam has already finished, skip ngram check
|
||||
if ((finished_buf != nullptr) && (finished_buf[id_offset + local_batch_idx * beam_width + beam_idx]))
|
||||
if ((finished_buf != nullptr) && (finished_buf[id_offset + local_batch_idx * beam_width + beam_idx].isFinished()))
|
||||
{
|
||||
return;
|
||||
}
|
||||
@ -134,9 +134,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bo
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, const int* parent_ids_buf,
|
||||
int batch_size, int local_batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
|
||||
int vocab_size_padded, size_t step, cudaStream_t stream)
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width,
|
||||
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream)
|
||||
{
|
||||
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
|
||||
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
|
||||
@ -160,7 +160,7 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* fin
|
||||
}
|
||||
|
||||
#define INVOKE_BAN_REPEAT_NGRAM(T) \
|
||||
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, \
|
||||
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, \
|
||||
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width, \
|
||||
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream);
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@ -25,9 +26,9 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, const int* parent_ids_buf,
|
||||
int batch_size, int local_batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
|
||||
int vocab_size_padded, size_t step, cudaStream_t stream);
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width,
|
||||
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -756,8 +756,8 @@ template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_s
|
||||
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
|
||||
const int batch_size, const int beam_width)
|
||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished,
|
||||
const float* cum_log_probs, const int batch_size, const int beam_width)
|
||||
{
|
||||
const int bid = blockIdx.x;
|
||||
const int tgt_start_idx = beam_hyps.num_beams[bid];
|
||||
@ -802,7 +802,7 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finis
|
||||
// TODO huggingface uses total length to normalize the scores, instead of number of generated tokens.
|
||||
// Check whether that is reasonable or not.
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty(cum_log_probs[src_beam_idx],
|
||||
finished[src_beam_idx] ? last_token_idx + 1 : last_token_idx, length_penalty);
|
||||
finished[src_beam_idx].isFinished() ? last_token_idx + 1 : last_token_idx, length_penalty);
|
||||
beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx];
|
||||
|
||||
beam_hyps.num_beams[bid]++;
|
||||
@ -810,7 +810,7 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finis
|
||||
}
|
||||
}
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
||||
const int batch_size, const int beam_width, cudaStream_t stream)
|
||||
{
|
||||
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
|
||||
|
||||
@ -14,10 +14,11 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
@ -76,7 +77,7 @@ void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequen
|
||||
const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
const size_t d_model, cudaStream_t stream);
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
||||
const int batch_size, const int beam_width, cudaStream_t stream);
|
||||
|
||||
void invokeCopyBatchMajorToGeneralPtr(
|
||||
|
||||
@ -82,13 +82,16 @@ public:
|
||||
, mHeadSize(headSize)
|
||||
, mQScaling(qScaling)
|
||||
, sm(sm_)
|
||||
, xmmaKernel(getXMMAKernelsV2(data_type, sm_))
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
(sm == kSM_80 || sm == kSM_86 || sm == kSM_89 || sm == kSM_90), "Unsupported architecture");
|
||||
TLLM_CHECK_WITH_INFO((mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16), "Unsupported data type");
|
||||
|
||||
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
|
||||
xmmaKernel = getXMMAKernelsV2(mDataType, sm);
|
||||
|
||||
params.clear();
|
||||
pagedKVParams.clear();
|
||||
|
||||
// get device attributes
|
||||
int device_id;
|
||||
@ -99,6 +102,7 @@ public:
|
||||
|
||||
~mhaImpl() {}
|
||||
|
||||
// Support packed QKV.
|
||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
{
|
||||
@ -185,6 +189,95 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Support paged_kv_cache and chunked_attention.
|
||||
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
{
|
||||
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
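// I.e. when scale_alibi is set, the 1/(sqrt(d) * q_scaling) factor is applied after the ALiBi
// bias is added; otherwise it is folded directly into the BMM1 scale.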
|
||||
|
||||
Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
|
||||
set_alpha(pagedKVParams.scale_bmm1, scale_bmm1, scale_type);
|
||||
set_alpha(pagedKVParams.scale_softmax, scale_softmax, scale_type);
|
||||
set_alpha(pagedKVParams.scale_bmm2, scale_bmm2, scale_type);
|
||||
|
||||
pagedKVParams.b = b;
|
||||
pagedKVParams.h = mNumHeads;
|
||||
pagedKVParams.s = s_q;
|
||||
pagedKVParams.d = mHeadSize;
|
||||
|
||||
// Total sequence length needed by TMA descriptor
|
||||
// it should be actual total seq length if non-padded input is given.
|
||||
mTotalSeqLen = total_seqlen;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(tokens_per_kv_block >= 128, "FMHA with paged kv cache needs tokens_per_block >= 128 !");
|
||||
// Needed by TMA descriptors.
|
||||
launch_params.blocks_per_context_sequence = blocks_per_context_sequence;
|
||||
pagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
|
||||
pagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
|
||||
pagedKVParams.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
|
||||
|
||||
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
|
||||
const bool isSm90 = (sm == kSM_90);
|
||||
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
const bool isSm80 = (sm == kSM_80);
|
||||
|
||||
// always use flash attention kernels.
launch_params.flash_attention = true;
// flash attention kernels use s = 0 (they support any sequence length)
launch_params.kernel_s = 0;
launch_params.kernel_kv_s = s_kv;
launch_params.force_unroll = true;
|
||||
|
||||
// enable warp-specialization kernels when s > 512.
|
||||
if (isSm90 && s_kv > 512)
|
||||
{
|
||||
launch_params.warp_specialization = true;
|
||||
launch_params.use_tma = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// enable tiled kernels on Ampere/Ada
|
||||
if (launch_params.flash_attention && s_kv <= 64)
|
||||
{
|
||||
// flash attention tiled kernels allow a larger free dim tile size (M, N) with flexibility
// in the unroll dimension tile size (K). For short sequence lengths (s <= 128), tiled kernels
// can suffer from tile quantization loss; therefore, use the non-tiled flash attention kernels instead.
|
||||
launch_params.granular_tiling = false;
|
||||
}
|
||||
else if (isSm8x && params.d < 256)
|
||||
{
|
||||
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
|
||||
launch_params.granular_tiling = false;
|
||||
}
|
||||
else if (isSm90 || isSm80 || isSm8x)
|
||||
{
|
||||
// otherwise, choose tiled kernel for Ampere/Ada
|
||||
launch_params.granular_tiling = true;
|
||||
}
|
||||
}
|
||||
|
||||
// alibi.
|
||||
if (has_alibi)
|
||||
{
|
||||
pagedKVParams.has_alibi = true;
|
||||
pagedKVParams.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
|
||||
}
|
||||
|
||||
// Sliding_window_causal mask.
|
||||
if (s_kv > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
|
||||
{
|
||||
pagedKVParams.sliding_window_size = sliding_window_size;
|
||||
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded
|
||||
// TMA descriptors are used as grid_constant parameters (remove MemCpyH2D operations)
|
||||
void set_tma_descriptors()
|
||||
@ -246,14 +339,14 @@ public:
|
||||
if (sTmaMetaInfo[i].mD == params.d)
|
||||
{
|
||||
q_step = sTmaMetaInfo[i].mQStep;
|
||||
kv_step = sTmaMetaInfo[i].mKvStep;
|
||||
kv_step = sTmaMetaInfo[i].mKVStep;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// QKV [TOTAL, 3, h, d]
|
||||
// NOTE: we may need to use actual seqlen to set oob_value
|
||||
char* qkv_ptr = reinterpret_cast<char*>(params.qkv_ptr);
|
||||
const char* qkv_ptr = reinterpret_cast<const char*>(params.qkv_ptr);
|
||||
tensor_size_qkv[3] = mTotalSeqLen;
|
||||
|
||||
// Q: STEP_Q
|
||||
@ -275,12 +368,131 @@ public:
|
||||
¶ms.tma_desc_v);
|
||||
}
|
||||
|
||||
// Q are contiguous in the shape of [B, S, H, D]
|
||||
// Paged KV has [B, 2, NumBlocksPerSequence] buffers,
|
||||
// and each points to the contiguous buffer with shape [H, TokensPerBlock, D]
|
||||
// TMA descriptors need cudaMemcpyAsync since we need multiple tma descriptors in device memory.
|
||||
void set_paged_kv_tma_descriptors(cudaStream_t stream)
|
||||
{
|
||||
// split D into multiple groups in order to match the TMA swizzle mode (128B)
|
||||
const uint32_t d_in_bytes = pagedKVParams.d * sizeof(uint16_t);
|
||||
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
|
||||
uint32_t q_step = 0, kv_step = 0;
|
||||
for (unsigned int i = 0u; i < sizeof(sTmaPagedKVMetaInfo) / sizeof(sTmaPagedKVMetaInfo[0]); ++i)
|
||||
{
|
||||
if (sTmaPagedKVMetaInfo[i].mD == pagedKVParams.d)
|
||||
{
|
||||
q_step = sTmaPagedKVMetaInfo[i].mQStep;
|
||||
kv_step = sTmaPagedKVMetaInfo[i].mKVStep;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Separate q, and paged kv tma descriptors.
|
||||
Multiple_tma_descriptor<4> q_tma_descriptor;
|
||||
Multiple_tma_descriptor<4> paged_kv_tma_descriptor(
|
||||
pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence);
|
||||
// Contiguous Q
|
||||
// query tensor size [B x S, 1, H, D]
|
||||
uint32_t tensor_size_q[4];
|
||||
tensor_size_q[3] = mTotalSeqLen;
|
||||
tensor_size_q[2] = 1;
|
||||
tensor_size_q[1] = pagedKVParams.h;
|
||||
tensor_size_q[0] = pagedKVParams.d;
|
||||
|
||||
// box size for k and v
|
||||
uint32_t box_size_q[4];
|
||||
box_size_q[3] = q_step;
|
||||
box_size_q[2] = 1;
|
||||
box_size_q[1] = 1;
|
||||
box_size_q[0] = pagedKVParams.d / d_groups;
|
||||
|
||||
// stride size in bytes.
|
||||
uint64_t tensor_stride_q[3];
|
||||
tensor_stride_q[0] = tensor_size_q[0] * sizeof(uint16_t);
|
||||
tensor_stride_q[1] = tensor_size_q[1] * tensor_stride_q[0];
|
||||
tensor_stride_q[2] = tensor_size_q[2] * tensor_stride_q[1];
|
||||
|
||||
// traversal stride
|
||||
uint32_t traversal_stride[4] = {1, 1, 1, 1};
|
||||
|
||||
// OOB fill zeros
|
||||
uint32_t oob_fill = 0;
|
||||
|
||||
// FP32 to TF32 conversion disabled
|
||||
uint32_t fp32_to_tf32 = 0;
|
||||
|
||||
// gmma descriptor mode
const uint32_t d_bytes_per_group = (pagedKVParams.d * sizeof(uint16_t)) / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
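// E.g. for fp16 with d = 128: 256 bytes per head -> 2 groups of 128 bytes -> SWIZZLE_128B;
// for d = 32: 64 bytes in a single group -> SWIZZLE_64B.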
|
||||
|
||||
// Q ptr.
|
||||
const char* q_ptr = reinterpret_cast<const char*>(pagedKVParams.q_ptr);
|
||||
|
||||
// Q: STEP_Q.
|
||||
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_q, tensor_stride_q, traversal_stride, box_size_q, oob_fill, fp32_to_tf32,
|
||||
&pagedKVParams.tma_desc_q);
|
||||
|
||||
// Paged KV
|
||||
// Per batch tensor size.
|
||||
uint32_t tensor_size_kv[4];
|
||||
tensor_size_kv[3] = 1;
|
||||
tensor_size_kv[2] = pagedKVParams.h_kv;
|
||||
tensor_size_kv[1] = pagedKVParams.paged_kv_cache.mTokensPerBlock;
|
||||
tensor_size_kv[0] = pagedKVParams.d;
|
||||
|
||||
// Box size for k and v.
|
||||
uint32_t box_size_kv[4];
|
||||
box_size_kv[3] = 1;
|
||||
box_size_kv[2] = 1;
|
||||
box_size_kv[1] = kv_step;
|
||||
box_size_kv[0] = pagedKVParams.d / d_groups;
|
||||
|
||||
// Stride size in bytes.
|
||||
uint64_t tensor_stride_kv[3];
|
||||
tensor_stride_kv[0] = tensor_size_kv[0] * sizeof(uint16_t);
|
||||
tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0];
|
||||
tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1];
|
||||
|
||||
// 2 stands for the k and v blocks.
// We only need to prepare as many tma descriptors as the number of paged kv blocks for context.
for (int block_idx = 0; block_idx < pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence;
block_idx++)
{
int block_ptr_idx = int(block_idx / launch_params.blocks_per_context_sequence)
* pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
+ (block_idx % launch_params.blocks_per_context_sequence);
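// The block-pointer table is laid out per (batch, K/V) pair with a row stride of
// mMaxBlocksPerSeq; block_idx / blocks_per_context_sequence selects the row and the
// remainder selects the context block within it.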
|
||||
paged_kv_tma_descriptor.set_tma_desctriptor(
|
||||
reinterpret_cast<char*>(launch_params.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, block_idx);
|
||||
}
|
||||
|
||||
// set mMaxBlocksPerSeq to the number of blocks needed for context.
|
||||
pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = launch_params.blocks_per_context_sequence;
|
||||
|
||||
paged_kv_tma_descriptor.copy_to_device(pagedKVParams.tma_desc_paged_kv, stream);
|
||||
}
|
||||
|
||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
|
||||
{
|
||||
// BF16 FMHA only accumulates on FP32
|
||||
launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
launch_params.attention_mask_type
|
||||
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
|
||||
|
||||
// Paged KV Cache.
pagedKVParams.h_kv = num_kv_heads;
TLLM_CHECK_WITH_INFO(mNumHeads % num_kv_heads == 0, "number of Query heads should be multiple of KV heads !");
pagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
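// E.g. 32 query heads with num_kv_heads = 4 (GQA) gives h_q_per_kv = 8 query heads per KV head.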
|
||||
pagedKVParams.is_s_padded = is_s_padded;
|
||||
|
||||
// Contiguous Cache.
|
||||
params.h_kv = num_kv_heads;
|
||||
params.is_s_padded = is_s_padded;
|
||||
}
|
||||
@ -290,11 +502,11 @@ public:
|
||||
return MHARunner::fmha_supported(mHeadSize, sm);
|
||||
}
|
||||
|
||||
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* output, cudaStream_t stream)
|
||||
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
{
|
||||
params.qkv_ptr = const_cast<void*>(qkvPtr);
|
||||
params.o_ptr = output;
|
||||
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cuSeqlenPtr));
|
||||
params.qkv_ptr = qkvPtr;
|
||||
params.o_ptr = outputPtr;
|
||||
params.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
|
||||
|
||||
if (sm == kSM_90 && launch_params.use_tma)
|
||||
{
|
||||
@ -305,9 +517,31 @@ public:
|
||||
xmmaKernel->run(params, launch_params, stream);
|
||||
}
|
||||
|
||||
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
pagedKVParams.q_ptr = qPtr;
|
||||
pagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
|
||||
pagedKVParams.paged_kv_cache = pagedKVCache;
|
||||
pagedKVParams.o_ptr = outputPtr;
|
||||
pagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
|
||||
pagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
|
||||
// paged kv block device ptrs on host (used by tma descriptors).
|
||||
launch_params.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
|
||||
|
||||
if (sm == kSM_90 && launch_params.use_tma)
|
||||
{
|
||||
// memcpy H2D is needed as we use multiple tma descriptors in device memory.
|
||||
set_paged_kv_tma_descriptors(stream);
|
||||
}
|
||||
|
||||
pagedKVXmmaKernel->run(pagedKVParams, launch_params, stream);
|
||||
}
|
||||
|
||||
bool isValid(int s) const
|
||||
{
|
||||
return xmmaKernel->isValid(s);
|
||||
return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s);
|
||||
}
|
||||
|
||||
int getSFromMaxSeqLen(const int max_seq_len)
|
||||
@ -345,9 +579,11 @@ public:
|
||||
|
||||
private:
|
||||
Fused_multihead_attention_params_v2 params;
|
||||
Fused_multihead_attention_paged_kv_params_v2 pagedKVParams;
|
||||
Launch_params launch_params;
|
||||
int sm;
|
||||
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
|
||||
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;
|
||||
bool use_flash_attention = false;
|
||||
const Data_type mDataType;
|
||||
const int mNumHeads;
|
||||
@ -372,6 +608,14 @@ void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_
|
||||
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
{
|
||||
pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size,
|
||||
total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
}
|
||||
|
||||
bool FusedMHARunnerV2::fmha_supported()
|
||||
{
|
||||
return pimpl->fmha_supported();
|
||||
@ -383,9 +627,17 @@ void FusedMHARunnerV2::setup_flags(
|
||||
pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* output, cudaStream_t stream)
|
||||
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
{
|
||||
pimpl->run(qkvPtr, cuSeqlenPtr, output, stream);
|
||||
pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
pimpl->run_paged_kv(
|
||||
qPtr, pagedKVTmaDesc, pagedKVBlockPtrsOnHost, pagedKVCache, cuQSeqlenPtr, cuKVSeqlenPtr, outputPtr, stream);
|
||||
}
|
||||
|
||||
bool FusedMHARunnerV2::isValid(int s) const
|
||||
|
||||
@ -51,6 +51,11 @@ public:
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
static bool fmha_supported(const int headSize, const int sm);
|
||||
|
||||
virtual bool fmha_supported() = 0;
|
||||
@ -61,6 +66,11 @@ public:
|
||||
|
||||
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
|
||||
|
||||
virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
|
||||
cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
virtual bool isValid(int s) const = 0;
|
||||
};
|
||||
|
||||
@ -84,9 +94,17 @@ public:
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
|
||||
const int tp_rank = 0) override;
|
||||
|
||||
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
|
||||
const int tp_rank = 0) override;
|
||||
|
||||
bool fmha_supported() override;
|
||||
|
||||
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
|
||||
void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
|
||||
const int num_kv_heads /* MQA or GQA */) override;
|
||||
|
||||
@ -16,26 +16,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/kvCacheUtils.h"
|
||||
#include "tmaDescriptor.h"
|
||||
#include <limits.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
enum Data_type
|
||||
{
|
||||
DATA_TYPE_BOOL,
|
||||
DATA_TYPE_FP16,
|
||||
DATA_TYPE_FP32,
|
||||
DATA_TYPE_INT4,
|
||||
DATA_TYPE_INT8,
|
||||
DATA_TYPE_INT32,
|
||||
DATA_TYPE_BF16,
|
||||
DATA_TYPE_E4M3,
|
||||
DATA_TYPE_E5M2
|
||||
};
|
||||
|
||||
enum class ContextFMHAType
|
||||
{
|
||||
@ -52,14 +43,6 @@ enum class ContextAttentionMaskType
|
||||
SLIDING_WINDOW_CAUSAL
|
||||
};
|
||||
|
||||
constexpr int32_t kSM_70 = 70;
|
||||
constexpr int32_t kSM_72 = 72;
|
||||
constexpr int32_t kSM_75 = 75;
|
||||
constexpr int32_t kSM_80 = 80;
|
||||
constexpr int32_t kSM_86 = 86;
|
||||
constexpr int32_t kSM_89 = 89;
|
||||
constexpr int32_t kSM_90 = 90;
|
||||
|
||||
struct AlibiParams
|
||||
{
|
||||
constexpr static int round_down_to_power_two(int x)
|
||||
@ -101,9 +84,9 @@ struct AlibiParams
|
||||
struct Fused_multihead_attention_params_v2
|
||||
{
|
||||
// The QKV matrices.
|
||||
void* qkv_ptr;
|
||||
const void* qkv_ptr;
|
||||
// The mask to implement drop-out.
|
||||
void* packed_mask_ptr;
|
||||
const void* packed_mask_ptr;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
|
||||
@ -123,7 +106,7 @@ struct Fused_multihead_attention_params_v2
|
||||
bool enable_i2f_trick;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual sequence lengths
|
||||
int* cu_seqlens;
|
||||
const int* cu_seqlens;
|
||||
|
||||
// use C/32 Format.
|
||||
bool interleaved = false;
|
||||
@ -183,6 +166,102 @@ struct Fused_multihead_attention_params_v2
|
||||
use_int8_scale_max = false;
|
||||
|
||||
h_kv = 0;
|
||||
sliding_window_size = INT_MAX;
|
||||
is_s_padded = false;
|
||||
|
||||
has_alibi = false;
|
||||
alibi_params = AlibiParams{};
|
||||
}
|
||||
};
|
||||
|
||||
struct Fused_multihead_attention_paged_kv_params_v2
|
||||
{
|
||||
// The Q matrices.
|
||||
const void* q_ptr;
|
||||
// Paged KV Cache buffer.
|
||||
KVBlockArray paged_kv_cache;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
// The packed mask for random mask.
|
||||
const void* packed_mask_ptr;
|
||||
|
||||
// The stride between rows of the Q matrices.
|
||||
int64_t q_stride_in_bytes;
|
||||
// The stride between rows of the paged KV matrices.
|
||||
int64_t kv_stride_in_bytes;
|
||||
// The stride between rows of O.
|
||||
int64_t o_stride_in_bytes;
|
||||
// The stride between matrices of packed mask.
|
||||
int64_t packed_mask_stride_in_bytes;
|
||||
|
||||
// The dimensions.
|
||||
int b, h, s, d;
|
||||
// The scaling factors for the kernel.
|
||||
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
|
||||
|
||||
// Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
|
||||
// See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details.
|
||||
bool enable_i2f_trick;
|
||||
|
||||
// true: for int8, instead of doing max reduce, use max value encoded in scale factor
|
||||
bool use_int8_scale_max = false;
|
||||
|
||||
// If the kernel is using alibi or not
|
||||
bool has_alibi = false;
|
||||
AlibiParams alibi_params;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual kv sequence lengths.
|
||||
const int* cu_seqlens;
|
||||
// Chunked attention (only handles one tile of Q).
|
||||
const int* cu_q_seqlens;
|
||||
|
||||
// q with shape [B, S, H, D] in const cache.
|
||||
cudaTmaDesc tma_desc_q;
|
||||
// Tma descriptors for paged kv cache.
|
||||
// paged kv has [B, 2, Num_blocks] buffers,
|
||||
// and each points to [Tokens_per_block, H, D] contiguous memory.
|
||||
cudaTmaDesc* tma_desc_paged_kv;
|
||||
|
||||
// In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
|
||||
int h_kv = 0;
|
||||
int h_q_per_kv = 0;
|
||||
|
||||
// Sliding Window Attention
|
||||
// Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
|
||||
int sliding_window_size = INT_MAX;
|
||||
|
||||
// is input/output padded
|
||||
bool is_s_padded = false;
|
||||
|
||||
void clear()
|
||||
{
|
||||
q_ptr = nullptr;
|
||||
o_ptr = nullptr;
|
||||
packed_mask_ptr = nullptr;
|
||||
|
||||
q_stride_in_bytes = 0;
|
||||
kv_stride_in_bytes = 0;
|
||||
o_stride_in_bytes = 0;
|
||||
packed_mask_stride_in_bytes = 0;
|
||||
|
||||
b = 0;
|
||||
h = 0;
|
||||
s = 0;
|
||||
d = 0;
|
||||
// The scaling factors for the kernel.
|
||||
scale_bmm1 = 0;
|
||||
scale_softmax = 0;
|
||||
scale_bmm2 = 0;
|
||||
|
||||
enable_i2f_trick = false;
|
||||
|
||||
cu_seqlens = nullptr;
|
||||
cu_q_seqlens = nullptr;
|
||||
use_int8_scale_max = false;
|
||||
|
||||
h_kv = 0;
|
||||
h_q_per_kv = 0;
|
||||
sliding_window_size = INT_MAX;
|
||||
is_s_padded = false;
|
||||
|
||||
has_alibi = false;
|
||||
@ -195,6 +274,8 @@ struct Launch_params
|
||||
{
|
||||
// seq_length to select the kernel
|
||||
int kernel_s = 0;
|
||||
// kv_seq_length to set launch strategies.
|
||||
int kernel_kv_s = 0;
|
||||
// flags to control small batch kernel choice
|
||||
// true: never unroll
|
||||
bool ignore_b1opt = false;
|
||||
@ -208,6 +289,10 @@ struct Launch_params
|
||||
bool use_tma = false;
|
||||
// host seqlens to set tma descriptors
|
||||
int* seqlens = nullptr;
|
||||
// number of paged kv blocks for context sequence.
|
||||
int blocks_per_context_sequence = 0;
|
||||
// device ptrs on the host for paged kv cache.
|
||||
const int64_t* paged_kv_block_ptrs = nullptr;
|
||||
// if flash attention is used (only FP16)
|
||||
bool flash_attention = false;
|
||||
// if warp_specialized kernels are used (only SM90 HGMMA + TMA)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "cubin/fmha_cubin.h"
|
||||
#include "cuda_runtime_api.h"
|
||||
#include "fused_multihead_attention_common.h"
|
||||
#include "pagedKVCubin/fmha_cubin.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tmaDescriptor.h"
|
||||
@ -44,9 +45,19 @@ static const struct TmaKernelMetaInfo
|
||||
{
|
||||
unsigned int mD;
|
||||
unsigned int mQStep;
|
||||
unsigned int mKvStep;
|
||||
unsigned int mKVStep;
|
||||
} sTmaMetaInfo[] = {{32, 64, 256}, {64, 64, 256}, {128, 64, 128}, {256, 64, 64}};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// meta info for tma warp-specialized kernels that supports paged kv cache
|
||||
static const struct TmaPagedKVKernelMetaInfo
|
||||
{
|
||||
unsigned int mD;
|
||||
unsigned int mQStep;
|
||||
unsigned int mKVStep;
|
||||
} sTmaPagedKVMetaInfo[] = {{32, 64, 128}, {64, 64, 128}, {128, 64, 128}, {256, 64, 64}};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Base Class
|
||||
|
||||
@ -206,6 +217,8 @@ private:
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FMHA kernels that support Contiguous QKV input.
|
||||
|
||||
class FusedMultiHeadAttentionXMMAKernelV2
|
||||
: public TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2,
|
||||
Fused_multihead_attention_params_v2>
|
||||
@ -291,7 +304,7 @@ public:
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
} // forceunroll = true for flash attention kernels
|
||||
else if (mSM == kSM_90 && launch_params.flash_attention)
|
||||
else if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
|
||||
{
|
||||
// tricks for launching warp-specialized flash attention kernels on Hopper
|
||||
dim3 block_size(1, std::min(params.b * params.h, launch_params.multi_processor_count));
|
||||
@ -302,9 +315,9 @@ public:
|
||||
size_t m_steps = size_t((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep);
|
||||
m_steps = size_t((m_steps + NUM_COMPUTE_GROUPS - 1) / NUM_COMPUTE_GROUPS) * NUM_COMPUTE_GROUPS;
|
||||
|
||||
size_t size_in_bytes = params.b * params.s * params.qkv_stride_in_bytes;
|
||||
if (launch_params.attention_mask_type == ContextAttentionMaskType::PADDING
|
||||
&& size_in_bytes <= launch_params.device_l2_cache_size)
|
||||
// 2 * 2 stands for the K and V caches and 2 bytes per element.
|
||||
size_t size_in_bytes = block_size.y * params.s * params.d * 2 * 2;
|
||||
if (size_in_bytes <= launch_params.device_l2_cache_size)
|
||||
{
|
||||
// strategy 1: limit to only 1 wave
|
||||
block_size.x = std::min(m_steps / NUM_COMPUTE_GROUPS, sms_per_head);
|
||||
@ -328,8 +341,8 @@ public:
|
||||
{
|
||||
unroll = (params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
|
||||
}
|
||||
// on Hopper, we still launch blocks (h, b, steps)
|
||||
if (mSM == kSM_90)
|
||||
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
|
||||
if (mSM == kSM_90 && !launch_params.flash_attention)
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
@ -353,5 +366,115 @@ inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type typ
|
||||
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FMHA kernels that support Paged KV Cache and Chunked Attention.
|
||||
|
||||
class FusedMultiHeadAttentionPagedKVXMMAKernelV2
|
||||
: public TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
|
||||
Fused_multihead_attention_paged_kv_params_v2>
|
||||
{
|
||||
public:
|
||||
FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart,
|
||||
unsigned int nMetaCount, Data_type type, unsigned int sm)
|
||||
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
|
||||
Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm)
|
||||
{
|
||||
}
|
||||
|
||||
inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
bool flash_attention, bool warp_specialization, int attention_mask_type, bool tiled) const
{
s = flash_attention ? 0 : s;
// D <= 2048
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (warp_specialization ? 16ull : 0ull)
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}
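// Bit layout (low to high): unroll, interleaved, flash_attention, force_fp32_acc,
// warp_specialization/tiled (both map to bit 4 here), attention mask type from bit 6,
// head size d from bit 16, and the kernel sequence length from bit 32 (0 for flash attention).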
|
||||
|
||||
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
|
||||
{
|
||||
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
|
||||
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
|
||||
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
|
||||
}
|
||||
|
||||
virtual void run(
|
||||
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
|
||||
{
|
||||
|
||||
const auto findIter = mFunctions.find(
|
||||
hashID(launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
|
||||
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
|
||||
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
|
||||
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
|
||||
|
||||
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
|
||||
void* kernelParams[] = {¶ms, nullptr};
|
||||
|
||||
if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
|
||||
{
|
||||
// tricks for launching warp-specialized flash attention kernels on Hopper
|
||||
dim3 block_size(1, std::min(params.b * params.h, launch_params.multi_processor_count));
|
||||
|
||||
// distribute m steps to multiple blocks (fully utilize SMs)
|
||||
// block.x = blocks that handle single head, block.y = blocks that handle different heads
|
||||
size_t sms_per_head = (launch_params.multi_processor_count) / block_size.y;
|
||||
size_t m_steps = size_t((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep);
|
||||
m_steps = size_t((m_steps + NUM_COMPUTE_GROUPS - 1) / NUM_COMPUTE_GROUPS) * NUM_COMPUTE_GROUPS;
|
||||
|
||||
// 2 * 2 stands for the K and V caches and 2 bytes per element.
|
||||
size_t size_in_bytes = block_size.y * launch_params.kernel_kv_s * params.d * 2 * 2;
|
||||
if (size_in_bytes <= launch_params.device_l2_cache_size)
|
||||
{
|
||||
// strategy 1: limit to only 1 wave
|
||||
block_size.x = std::min(m_steps / NUM_COMPUTE_GROUPS, sms_per_head);
|
||||
}
|
||||
else
|
||||
{
|
||||
// strategy 2: fully unroll the q loops (contiguous blocks handle all q loops)
|
||||
block_size.x = m_steps / NUM_COMPUTE_GROUPS;
|
||||
}
|
||||
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, block_size.x, block_size.y, block_size.z, kernelMeta.mThreadsPerCTA,
|
||||
1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
else
|
||||
{ // forceunroll = true for flash attention kernels
|
||||
int unroll = kernelMeta.mS / kernelMeta.mUnrollStep;
|
||||
TLLM_CHECK_WITH_INFO(kernelMeta.mS == kernelMeta.mUnrollStep * unroll, "Wrong launching sequence length");
|
||||
// flash attention supports any sequence length, so we use the runtime s here
|
||||
if (launch_params.flash_attention)
|
||||
{
|
||||
unroll = (params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
|
||||
}
|
||||
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
|
||||
if (mSM == kSM_90 && !launch_params.flash_attention)
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
} // on Ampere/Ada flash attention, we launch blocks (steps, h, b)
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>;
|
||||
|
||||
inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
|
||||
{
|
||||
return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2,
|
||||
sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
Some files were not shown because too many files have changed in this diff