Update TensorRT-LLM (#524)

Kaiyu Xie 2023-12-01 22:27:51 +08:00, committed by GitHub
parent 711a28d9bf
commit 71f60f6df0
464 changed files with 2098084 additions and 6802 deletions

.gitignore

@ -18,6 +18,10 @@ venv/
.hypothesis/
.idea/
cpp/cmake-build-*
cpp/.ccache/
tensorrt_llm/libs
tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi
# Testing
.coverage.*

.gitmodules

@ -1,7 +1,6 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v2.10.0
[submodule "3rdparty/json"]
path = 3rdparty/json
url = https://github.com/nlohmann/json.git


@ -15,7 +15,7 @@ repos:
rev: v4.1.0
hooks:
- id: check-added-large-files
exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin'
exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/'
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
@ -33,9 +33,7 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
exclude: |
(?x)^(
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/.*
)$
(?x)^(.*cubin.cpp$ | .*fmha_cubin.h)$
- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.10
hooks:

README.md

@ -36,7 +36,6 @@ H200 FP8 achieves 11,819 tok/s on Llama2-13B on a single GPU, and is up to 1.9x
[2023/10/4 - Perplexity](https://blog.perplexity.ai/blog/introducing-pplx-api) ;
[2023/9/27 - CloudFlare](https://www.cloudflare.com/press-releases/2023/cloudflare-powers-hyper-local-ai-inference-with-nvidia/);
## Table of Contents
- [TensorRT-LLM Overview](#tensorrt-llm-overview)
@ -186,7 +185,8 @@ TensorRT-LLM is rigorously tested on the following GPUs:
* [H100](https://www.nvidia.com/en-us/data-center/h100/)
* [L40S](https://www.nvidia.com/en-us/data-center/l40s/)
* [A100](https://www.nvidia.com/en-us/data-center/a100/)/[A30](https://www.nvidia.com/en-us/data-center/products/a30-gpu/)
* [A100](https://www.nvidia.com/en-us/data-center/a100/)
* [A30](https://www.nvidia.com/en-us/data-center/products/a30-gpu/)
* [V100](https://www.nvidia.com/en-us/data-center/v100/) (experimental)
If a GPU is not listed above, it is important to note that TensorRT-LLM is
@ -254,6 +254,7 @@ The list of supported models is:
* [LLaMA-v2](examples/llama)
* [Mistral](examples/llama)
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
* [Qwen](examples/qwen)
* [Replit Code](examples/mpt)
@ -261,7 +262,10 @@ The list of supported models is:
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models easier.
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder
support that covers many encoder-decoder models such as T5, Flan-T5, etc. We
unroll the exact model names in the list above to let users find specific
models more easily.
## Performance
@ -325,7 +329,11 @@ enable plugins, for example: `--use_gpt_attention_plugin`.
* MPI + Slurm
TensorRT-LLM is a [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might encounter interferences:
TensorRT-LLM is a
[MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package
that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are
running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might
encounter interference:
```
--------------------------------------------------------------------------
PMI2_Init failed to initialize. Return code: 14
@ -347,19 +355,123 @@ SLURM, depending upon the SLURM version you are using:
Please configure as appropriate and try again.
--------------------------------------------------------------------------
```
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a dedicated MPI environment, not the one provided by your Slurm allocation.
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm
node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a
dedicated MPI environment, not the one provided by your Slurm allocation.
For example: `mpirun -n 1 python3 examples/gpt/build.py ...`
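Since TensorRT-LLM discovers its rank and world size through `mpi4py`, a quick way to confirm which MPI environment a process actually sees is to query them from Python. The following is a minimal sketch (illustrative only; it assumes `tensorrt_llm` is installed and importable on the node):

```
# Minimal sanity check of the MPI environment seen by TensorRT-LLM.
# Launch it the same way you launch real workloads, e.g. `mpirun -n 1 python3 check_mpi.py`.
import tensorrt_llm

rank = tensorrt_llm.mpi_rank()              # rank of this process
world_size = tensorrt_llm.mpi_world_size()  # total number of MPI ranks
print(f"Process rank {rank} out of {world_size} rank(s)")
```

If the reported values are not what you expect, the process may be picking up the allocation's MPI environment rather than the dedicated one started by `mpirun -n 1`.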
## Release notes
* TensorRT-LLM requires TensorRT 9.1.0.4 and 23.08 containers.
* TensorRT-LLM requires TensorRT 9.2 and 23.10 containers.
### Change Log
#### Version 0.6.0
* Models
* ChatGLM3
* InternLM (contributed by @wangruohui)
* Mistral 7B (developed in collaboration with Mistral.AI)
* MQA/GQA support to MPT (and GPT) models (contributed by @bheilbrun)
* Qwen (contributed by @Tlntin and @zhaohb)
* Replit Code V-1.5 3B (external contribution)
* T5, mT5, Flan-T5 (Python runtime only)
* Features
* Add runtime statistics related to active requests and KV cache
utilization from the batch manager (see
the [batch manager](docs/source/batch_manager.md) documentation)
* Add `sequence_length` tensor to support proper lengths in beam-search
(when beam-width > 1 - see
[tensorrt_llm/batch_manager/GptManager.h](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* BF16 support for encoder-decoder models (Python runtime - see
[examples/enc_dec](examples/enc_dec/README.md))
* Improvements to memory utilization (CPU and GPU - including memory
leaks)
* Improved error reporting and memory consumption
* Improved support for stop and bad words
* INT8 SmoothQuant and INT8 KV Cache support for the Baichuan models (see
[examples/baichuan](examples/baichuan/README.md))
* INT4 AWQ Tensor Parallelism support and INT8 KV cache + AWQ/weight-only
support for the GPT-J model (see [examples/gptj](examples/gptj/README.md))
* INT4 AWQ support for the Falcon models
(see [examples/falcon](examples/falcon/README.md))
* LoRA support (functional preview only - limited to the Python runtime,
only QKV support and not optimized in terms of runtime performance) for
the GPT model (see the
[Run LoRA with the Nemo checkpoint](examples/gpt/README.md#Run-LoRA-with-the-Nemo-checkpoint)
in the GPT example)
* Multi-GPU support for encoder-decoder models (Python runtime - see
[examples/enc_dec](examples/enc_dec/README.md))
* New heuristic for launching the Multi-block Masked MHA kernel (similar
to FlashDecoding - see
[decoderMaskedMultiheadAttentionLaunch.h](cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h))
* Prompt-Tuning support for GPT and LLaMA models (see the
[Prompt-tuning](examples/gpt/README.md#Prompt-tuning) Section in the GPT example)
* Performance optimizations in various CUDA kernels
* Possibility to exclude input tokens from the output (see `excludeInputInOutput` in
[`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* Python binding for the C++ runtime (GptSession - see [`pybind`](cpp/tensorrt_llm/pybind))
* Support for different micro batch sizes for context and generation
phases with pipeline parallelism (see `GptSession::Config::ctxMicroBatchSize` and
`GptSession::Config::genMicroBatchSize` in
[tensorrt_llm/runtime/gptSession.h](cpp/include/tensorrt_llm/runtime/gptSession.h))
* Support for "remove input padding" for encoder-decoder models (see
[examples/enc_dec](examples/enc_dec/README.md))
* Support for context and generation logits (see `mComputeContextLogits` and
`mComputeGenerationLogits` in
[tensorrt_llm/runtime/gptModelConfig.h](cpp/include/tensorrt_llm/runtime/gptModelConfig.h))
* Support for `logProbs` and `cumLogProbs` (see `"output_log_probs"` and
`"cum_log_probs"` in [`GptManager`](cpp/include/tensorrt_llm/batch_manager/GptManager.h))
* Update to CUTLASS 3.x
* Bug fixes
* Fix for ChatGLM2 #93 and #138
* Fix tensor names error "RuntimeError: Tensor names
(`host_max_kv_cache_length`) in engine are not the same as expected in
the main branch" #369
* Fix weights split issue in BLOOM when `world_size = 2` ("array split
does not result in an equal division") #374
* Fix SmoothQuant multi-GPU failure with tensor parallelism is 2 #267
* Fix a crash in GenerationSession if stream keyword argument is not None
#202
* Fix a typo when calling PyNVML API [BUG] code bug #410
* Fix bugs related to the improper management of the `end_id` for various
models [C++ and Python]
* Fix memory leaks [C++ code and Python models]
* Fix the std::alloc error when running the gptManagerBenchmark -- issue
gptManagerBenchmark std::bad_alloc error #66
* Fix a bug in pipeline parallelism when beam-width > 1
* Fix a bug with Llama GPTQ due to improper support of GQA
* Fix issue #88
* Fix an issue with the Huggingface Transformers version #16
* Fix link jump in windows readme.md #30 - by @yuanlehome
* Fix typo in batchScheduler.h #56 - by @eltociear
* Fix typo #58 - by @RichardScottOZ
* Fix Multi-block MMHA: Difference between `max_batch_size` in the engine
builder and `max_num_sequences` in TrtGptModelOptionalParams? #65
* Fix the log message to be more accurate on KV cache #224
* Fix Windows release wheel installation: Failed to install the release
wheel for Windows using pip #261
* Fix missing torch dependencies: [BUG] The batch_manage.a choice error
in --cpp-only when torch's cxx_abi version is different with gcc #151
* Fix linking error during compiling google-test & benchmarks #277
* Fix logits dtype for Baichuan and ChatGLM: segmentation fault caused by
the lack of bfloat16 #335
* Minor bug fixes
#### Version 0.5.0
* TensorRT-LLM v0.5.0 is the first public release.
### Known Issues
* The hang reported in issue
[#149](https://github.com/triton-inference-server/tensorrtllm_backend/issues/149)
has not been reproduced by the TensorRT-LLM team. If it is caused by a bug
in TensorRT-LLM, that bug may be present in that release.
### Report Issues
You can use GitHub issues to report issues with TensorRT-LLM.


@ -18,9 +18,13 @@ instead, and be sure to set DLL paths as specified in
### 2. Launch C++ benchmarking (Fixed BatchSize/InputLen/OutputLen)
#### Prepare TensorRT-LLM engine(s)
Before you launch C++ benchmarking, please make sure that you have already built engine(s) using the TensorRT-LLM API; the C++ benchmarking code cannot generate engine(s) for you.
You can reuse the engine built by benchmarking code for Python Runtime, please see that [`document`](../python/README.md).
You can use the [`build.py`](../python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked the Python runtime, you can reuse the engine(s) built by that benchmarking code; please see that [`document`](../python/README.md).
#### Launch benchmarking
For detailed usage, you can do the following
```


@ -15,11 +15,14 @@
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
@ -42,9 +45,9 @@ namespace trt = nvinfer1;
class WorkItem
{
public:
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t RequestId)
WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t requestId)
: mInferenceRequest(ir)
, mRequestId(RequestId)
, mRequestId(requestId)
{
}
@ -100,18 +103,12 @@ public:
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
std::lock_guard<std::mutex> lk(mMutex);
if (hasInProgressReqId(requestId) || hasPendingReqId(requestId))
{
std::string errStr
= "requestId " + std::to_string(requestId) + " is already in progress, request is ignored.";
throw std::runtime_error(errStr);
}
else
{
auto workItem = std::make_shared<WorkItem>(request, requestId);
mPendingWorkItems.push_back(workItem);
mPendingWorkItemsReqIds.insert(workItem->requestId());
}
TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
"requestId %lu is already in progress, request is ignored.", requestId);
auto workItem = std::make_shared<WorkItem>(request, requestId);
mPendingWorkItems.push_back(workItem);
mPendingWorkItemsReqIds.insert(workItem->requestId());
}
/// @brief Get a new work item from the queue, and move it to the list of
@ -208,12 +205,12 @@ public:
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
const auto& input_ids_tensor = request->getInputTensor("input_ids");
std::vector<int64_t> tensorShape(input_ids_tensor->getShape().nbDims);
auto const inputLength = tensorShape[1];
auto const [specified, outputLength]
= request->getScalarValueFromTensor<int>("request_output_len", {1, 1}, false);
assert(specified);
auto const inputLength = request->getInputIds()->getSize();
auto const maxNewTokens = request->getMaxNewTokensNamed();
auto const& outputLengthTensor = maxNewTokens.tensor;
TLLM_CHECK_WITH_INFO(outputLengthTensor != nullptr && outputLengthTensor->getSize() > 0,
"Undefined scalar vector for %s", maxNewTokens.name.c_str());
auto const outputLength = *bufferCast<SizeType>(*outputLengthTensor);
auto const start = std::chrono::steady_clock::now();
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
}
@ -286,20 +283,15 @@ public:
mWorkItemsQueue.clear();
}
void enqueue(std::vector<NamedTensor> tensors, uint64_t requestId, bool streaming)
void enqueue(std::shared_ptr<InferenceRequest> const& request)
{
// Create InferenceRequest from a set of tensors
auto request = std::make_shared<InferenceRequest>(requestId);
TLLM_CHECK(request != nullptr);
auto const requestId = request->getRequestId();
if (requestId == mTerminateReqId)
{
mWorkItemsQueue.push(request, requestId);
return;
}
for (auto t : tensors)
{
request->emplaceInputTensor(t.name, std::move(t.tensor));
}
request->setIsStreaming(streaming);
// Enqueue
try
@ -307,11 +299,14 @@ public:
mRecorder->recordStart(request, requestId);
mWorkItemsQueue.push(request, requestId);
}
catch (const tc::TllmException& e)
{
throw;
}
catch (const std::exception& e)
{
throw std::runtime_error(e.what());
TLLM_THROW("%s", e.what());
}
return;
}
void waitForEmpty() const
@ -357,8 +352,8 @@ public:
}
else
{
std::string warnStr = std::string("request Id ") + std::to_string(workItem->requestId())
+ std::string(" has been stopped. Request is ignored.");
auto warnStr = tc::fmtstr(
"request Id %lu has been stopped. Request is ignored.", workItem->requestId());
TLLM_LOG_WARNING(warnStr);
sendResponse(workItem->requestId(), {}, true, warnStr);
}
@ -366,7 +361,7 @@ public:
if (world_size > 1)
{
std::vector<int64_t> packed;
for (auto ir : rval)
for (auto const& ir : rval)
{
auto vpacked = ir->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
@ -400,10 +395,9 @@ public:
return rval;
}
void sendResponse(uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
const std::string& errMsg)
void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
bool final_response, [[maybe_unused]] const std::string& errMsg)
{
std::string errStr = std::string("Failed to send response for requestId: ") + std::to_string(requestId);
try
{
if (final_response)
@ -414,7 +408,7 @@ public:
}
catch (const std::exception& e)
{
TLLM_LOG_ERROR(errStr);
TLLM_LOG_ERROR("Failed to send response for requestId: %ul\n%s", requestId, e.what());
}
}
@ -434,24 +428,43 @@ std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
{
auto constexpr allowExceptions = true;
auto constexpr ingoreComments = true;
TLLM_CHECK_WITH_INFO(
std::filesystem::exists(datasetPath), std::string("File does not exist: ") + datasetPath.string());
TLLM_CHECK_WITH_INFO(std::filesystem::exists(datasetPath), "File does not exist: %s", datasetPath.string().c_str());
std::ifstream jsonStream(datasetPath);
auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ingoreComments);
std::vector<std::vector<int32_t>> input_ids_list;
std::vector<int32_t> output_ids_list;
std::vector<std::vector<int32_t>> inputIds;
std::vector<int32_t> outputIds;
for (auto& sample : json)
{
input_ids_list.push_back(sample["input_ids"]);
output_ids_list.push_back(sample["output_len"]);
inputIds.push_back(sample["input_ids"]);
outputIds.push_back(sample["output_len"]);
}
return std::make_pair(input_ids_list, output_ids_list);
return std::make_pair(inputIds, outputIds);
}
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
std::string const& datasetPath, int beamWidth, int warmUp, std::shared_ptr<nvinfer1::ILogger> const& logger,
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy)
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> const& dataset, std::size_t sample_idx,
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager)
{
auto request = std::make_shared<InferenceRequest>(reqId);
auto const& inputIds = dataset.first[sample_idx];
request->setInputIds(bufferManager.copyFrom(
inputIds, ITensor::makeShape({1, static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
auto const request_output_len = dataset.second[sample_idx];
request->setMaxNewTokens(
bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
request->setBeamWidth(beamWidthTensor);
request->setEndId(eosId);
request->setPadId(padId);
return request;
}
void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::filesystem::path const& engineDir,
std::string const& type, std::string const& datasetPath, int beamWidth, int warmUp,
const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
std::shared_ptr<nvinfer1::ILogger> const& logger, TrtGptModelOptionalParams const& optionalParams,
batch_scheduler::SchedulerPolicy schedulerPolicy)
{
auto const worldConfig = WorldConfig::mpi(*logger);
@ -466,57 +479,49 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
}
else
{
const std::string errStr = std::string("Unexpected batching type: ") + type;
TLLM_LOG_ERROR(errStr);
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
}
ITensor::SharedPtr beamWidthBuffer = BufferManager::cpu(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
auto beamWidthBufferPtr = bufferCast<SizeType>(*beamWidthBuffer);
*beamWidthBufferPtr = beamWidth;
auto beamWidthTensor = NamedTensor(beamWidthBuffer, "beam_width");
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
ITensor::SharedPtr beamWidthTensor{
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};
// Load dataset
auto dataset = parseDataset(datasetPath);
std::vector<std::vector<NamedTensor>> tensors_list;
const auto num_samples = dataset.first.size();
for (int i = 0; i < num_samples; ++i)
{
const auto input_ids = dataset.first[i];
const auto request_output_len = dataset.second[i];
std::vector<int64_t> input_ids_shape = {1, static_cast<int64_t>(input_ids.size())};
auto input_ids_tensor = NamedTensor(nvinfer1::DataType::kINT32, input_ids_shape, "input_ids", input_ids.data());
auto request_output_len_tensor
= NamedTensor(nvinfer1::DataType::kINT32, {1, 1}, "request_output_len", &request_output_len);
std::vector<NamedTensor> tensors
= {std::move(input_ids_tensor), std::move(request_output_len_tensor), beamWidthTensor};
tensors_list.emplace_back(std::move(tensors));
}
const auto numSamples = dataset.first.size();
const int maxBeamWidth = beamWidth;
auto recorder = std::make_shared<Recorder>();
uint64_t terminateReqId = num_samples + 1;
uint64_t terminateReqId = numSamples + 1;
auto gptServer = std::make_shared<GptServer>(
engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId);
ITensor::SharedPtr eosIdTensor{
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
ITensor::SharedPtr padIdTensor{
padId ? bufferManager.copyFrom(&padId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
if (worldConfig.getRank() == 0)
{
// Warm up
for (auto i = 1; i < warmUp + 1; ++i)
SizeType reqId = 0;
for (auto i = 0; i < warmUp; ++i)
{
// skip terminateReqId
++reqId;
if (i == terminateReqId)
{
i += 1;
}
gptServer->enqueue(tensors_list[0], i, false);
++reqId;
auto request = makeRequest(reqId, dataset, 0, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
// Benchmark
recorder->initialize();
for (int i = 0; i < tensors_list.size(); ++i)
for (std::size_t i = 0; i < numSamples; ++i)
{
gptServer->enqueue(tensors_list[i], 1 + i, false);
auto request = makeRequest(i + 1, dataset, i, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
recorder->finalize();
@ -524,7 +529,7 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
recorder->report();
// Send terminateReqId to terminate servers on all ranks
// Sever on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue({}, terminateReqId, false);
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
}
// Wait until benchmarking is done and batch manager is terminated
gptServer->waitBatchManager();
@ -548,7 +553,8 @@ int main(int argc, char* argv[])
"beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
options.add_options()(
"warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
options.add_options()("eos_id", "Specify the end-of-sequence token id.", cxxopts::value<int>());
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<int>());
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()(
@ -609,6 +615,20 @@ int main(int argc, char* argv[])
optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
}
std::optional<int32_t> padId;
// Argument: Padding token id
if (result.count("pad_id"))
{
padId = result["pad_id"].as<int>();
}
std::optional<int32_t> eosId;
// Argument: End-of-sentence token id
if (result.count("eos_id"))
{
eosId = result["eos_id"].as<int>();
}
// Argument: Scheduler policy
batch_scheduler::SchedulerPolicy schedulerPolicy;
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
@ -660,7 +680,7 @@ int main(int argc, char* argv[])
try
{
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
datasetPath, beamWidth, result["warm_up"].as<int>(), logger, optionalParams, schedulerPolicy);
datasetPath, beamWidth, result["warm_up"].as<int>(), eosId, padId, logger, optionalParams, schedulerPolicy);
}
catch (const std::exception& e)
{


@ -40,13 +40,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits)
{
std::string modelNameHyphen = modelName;
std::filesystem::path jsonFileName = dataPath / "config.json";
if (tc::strStartsWith(modelName, "chatglm") || tc::strStartsWith(modelName, "glm"))
{
jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
}
auto const json = GptJsonConfig::parse(jsonFileName);
auto const modelConfig = json.getModelConfig();
auto const inputPacked = modelConfig.usePackedInput();


@ -8,6 +8,7 @@ multiple GPUs or multiple nodes with multiple GPUs.
The benchmark implementation and entrypoint can be found in [`benchmarks/python/benchmark.py`](./benchmark.py). There are some other scripts in the directory:
* [`benchmarks/python/allowed_configs.py`](./allowed_configs.py) to define configuration for each supported model.
* [`benchmarks/python/build.py`](./build.py) to build supported models for benchmarking.
* [`benchmarks/python/base_benchmark.py`](./base_benchmark.py) to implement the base class for benchmark.
* [`benchmarks/python/gpt_benchmark.py`](./gpt_benchmark.py) to implement benchmark scripts for GPT and GPT-like (LLaMA/OPT/GPT-J/SmoothQuant-GPT) models.
* [`benchmarks/python/bert_benchmark.py`](./bert_benchmark.py) to implement benchmark scripts for BERT models.
@ -30,9 +31,9 @@ python benchmark.py \
```
Expected outputs:
```
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 1 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 378.12 percentile95(ms) 53.284 percentile99(ms) 53.284 latency(ms) 52.893
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 8 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 361.06 percentile95(ms) 55.739 percentile99(ms) 55.739 latency(ms) 55.392
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 64 input_length 60 output_length 20 build_time(s) 89.8 tokens_per_sec 246.03 percentile95(ms) 81.533 percentile99(ms) 81.533 latency(ms) 81.29
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 1 input_length 60 output_length 20 gpu_peak_mem(gb) 4.2 build_time(s) 25.67 tokens_per_sec 483.54 percentile95(ms) 41.537 percentile99(ms) 42.102 latency(ms) 41.362 compute_cap sm80
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 8 input_length 60 output_length 20 gpu_peak_mem(gb) 4.28 build_time(s) 25.67 tokens_per_sec 3477.28 percentile95(ms) 46.129 percentile99(ms) 46.276 latency(ms) 46.013 compute_cap sm80
[BENCHMARK] model_name gpt_350m world_size 1 num_heads 16 num_kv_heads 16 num_layers 24 hidden_size 1024 vocab_size 51200 precision float16 batch_size 64 input_length 60 output_length 20 gpu_peak_mem(gb) 4.8 build_time(s) 25.67 tokens_per_sec 19698.07 percentile95(ms) 65.739 percentile99(ms) 65.906 latency(ms) 64.981 compute_cap sm80
...
```
*Please note that the expected outputs are only for reference; specific performance numbers depend on the GPU you're using.*
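Each `[BENCHMARK]` line is a flat sequence of `key value` pairs, which makes it easy to post-process. Below is a minimal parsing sketch (illustrative only; the field names simply follow the sample output above):

```
# Parse a "[BENCHMARK] key1 value1 key2 value2 ..." line into a dict.
def parse_benchmark_line(line):
    tokens = line.split()
    assert tokens[0] == "[BENCHMARK]"
    return dict(zip(tokens[1::2], tokens[2::2]))

sample = ("[BENCHMARK] model_name gpt_350m world_size 1 precision float16 "
          "batch_size 1 input_length 60 output_length 20 tokens_per_sec 483.54")
print(parse_benchmark_line(sample)["tokens_per_sec"])  # -> 483.54
```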


@ -30,18 +30,17 @@ class BuildConfig(BaseModel, extra=Extra.allow):
max_input_len: int
num_kv_heads: Optional[int] = None
max_output_len: Optional[int] = None
max_beam_width: int = 1
# TRT builder_optimization_level from 0 to 5
builder_opt: Optional[int] = None
inter_size: Optional[int] = None
rotary_dim: Optional[int] = None
type_vocab_size: Optional[int] = None
use_smooth_quant: bool = False
per_token: bool = False
per_channel: bool = False
pre_norm: Optional[bool] = None
do_layer_norm_before: Optional[bool] = None
enable_qk_half_accum: bool = False
enable_context_fmha: bool = True
enable_multi_block_mode: bool = False
# None means using the model family's default value defined in the ctor
position_embedding_type: Optional[PositionEmbeddingType] = None
# Only when position embedding is RoPE, this value makes sense, make
@ -49,6 +48,10 @@ class BuildConfig(BaseModel, extra=Extra.allow):
rotary_pct: Optional[float] = None
bias: bool = True
quantization: Optional[str] = None
# use_custom_all_reduce gives better performance with NVLink
use_custom_all_reduce: bool = True
moe_num_experts: int = None
moe_top_k: int = None
class ModelConfig(BaseModel):
@ -107,6 +110,24 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"gpt_350m_moe":
ModelConfig(name="gpt_350m_moe",
family="gpt",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=24,
num_heads=16,
hidden_size=1024,
vocab_size=51200,
hidden_act='gelu',
n_positions=1024,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
builder_opt=None,
moe_num_experts=8,
moe_top_k=1,
)),
"gpt_350m_sq_per_tensor":
ModelConfig(name="gpt_350m_sq_per_tensor",
family="gpt",
@ -301,6 +322,40 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"llama_70b_long_context":
ModelConfig(name="llama_70b_long_context",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(num_layers=80,
num_heads=64,
num_kv_heads=8,
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
inter_size=28672,
max_batch_size=16,
max_input_len=8000,
max_output_len=200,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_long_generation":
ModelConfig(name="llama_70b_long_generation",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(num_layers=80,
num_heads=64,
num_kv_heads=8,
hidden_size=8192,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
inter_size=28672,
max_batch_size=64,
max_input_len=200,
max_output_len=16384,
builder_opt=None,
enable_multi_block_mode=True)),
"llama_70b_sq_per_tensor":
ModelConfig(name="llama_70b_sq_per_tensor",
family="llama",


@ -15,11 +15,13 @@
import json
import os
import subprocess
import time
from collections import OrderedDict
import torch
import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode
@ -44,23 +46,28 @@ def get_engine_name(model, dtype, tp_size, rank):
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
with open(path, 'wb') as f:
# The engine object already complies with the Python buffer protocol, so there is no need to
# convert it to a bytearray before writing; the conversion would consume a lot of memory.
f.write(engine)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
class BaseBenchmark(object):
def __init__(self, engine_dir, model_name, dtype, output_dir):
def __init__(self, engine_dir, model_name, dtype):
self.engine_dir = engine_dir
self.model_name = model_name
self.dtype = dtype
self.output_dir = output_dir
self.runtime_rank = tensorrt_llm.mpi_rank()
self.world_size = tensorrt_llm.mpi_world_size()
self.engine_model_name = model_name
self.quant_mode = QuantMode(0)
self.enable_fp8 = False
if engine_dir is not None:
# Read config from engine directory
config_path = os.path.join(engine_dir, 'config.json')
@ -98,7 +105,7 @@ class BaseBenchmark(object):
self.csv_filename = "" # lazy init
def get_report_dict(self):
def get_report_dict(self, benchmark_profiler=None):
report_fields = [
"model_name", "world_size", "num_heads", "num_kv_heads",
"num_layers", "hidden_size", "vocab_size", "precision",
@ -122,9 +129,9 @@ class BaseBenchmark(object):
fp8linear=int(self.enable_fp8))
return self.csv_filename
def print_report_header(self, csv=False):
def print_report_header(self, csv=False, benchmark_profiler=None):
if csv and self.runtime_rank == 0:
report_dict = self.get_report_dict()
report_dict = self.get_report_dict(benchmark_profiler)
line = ",".join(report_dict.keys())
print(line)
with open(self.get_csv_filename(), "a") as file:
@ -136,7 +143,7 @@ class BaseBenchmark(object):
def prepare_inputs(self, config):
raise NotImplementedError
def run(self, inputs, config):
def run(self, inputs, config, benchmark_profiler=None):
raise NotImplementedError
def report(self, config, latency):


@ -118,7 +118,7 @@ def parse_arguments():
type=str,
default=None,
help=
'If this option is specified, TensorRT engines will be saved to engine_dir.'
'If this option is specified, TensorRT engines will be saved to the specified path.'
)
parser.add_argument(
'--engine_dir',
@ -128,37 +128,33 @@ def parse_arguments():
('If this option is specified, instead of building engines on the fly before benchmarking, '
'the engines contained in the engine_dir will be used.'))
parser.add_argument(
'--n_positions',
'--max_beam_width',
type=int,
default=None,
help=
('If this option is specified, it will override the n_positions of TRT engines to the specified value instead of using pre-defined one'
'By default when this option is not used, it will use pre-defined n_positions'
))
('If this option is specified, it will override the max beam width of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_input_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max input len of TRT engines to the specified value instead of using pre-defined one'
'By default when this option is not used, it will use pre-defined max input len'
))
('If this option is specified, it will override the max input len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_output_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max output len of TRT engines to the specified value instead of using pre-defined one'
'By default when this option is not used, it will use pre-defined max output len'
))
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
default=None,
help=
('If this option is specified, it will override the max batch size of TRT engines to the specified value instead of using pre-defined one'
'By default when this option is not used, it will use pre-defined max batch size'
))
('If this option is specified, it will override the max batch size of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--force_num_layer_1',
default=False,
@ -175,20 +171,6 @@ def parse_arguments():
default=False,
action='store_true',
help='Execute GPT session with CUDA graph.')
parser.add_argument(
'--enable_custom_all_reduce',
default=False,
action='store_true',
help=
'Use latency-optimized all-reduce for tensor parallelism. Gives better performance with NVLink.'
)
parser.add_argument(
'--strongly_typed',
default=False,
action='store_true',
help=
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
)
parser.add_argument(
'--quantization',
type=str,
@ -209,6 +191,7 @@ def main(args):
# the start method `spawn` of Python multiprocessing,
# so we set the start method first, then initialize MPI.
from allowed_configs import get_allowed_models
from benchmark_profiler import BenchmarkProfiler
from bert_benchmark import BERTBenchmark
from gpt_benchmark import GPTBenchmark
from mem_monitor import MemoryMonitor
@ -228,47 +211,19 @@ def main(args):
in_out_len_options = [[int(i) for i in io.split(',')]
for io in in_out_len_options]
benchmark_profiler = None
if args.model in get_allowed_models(benchmark_type="gpt"):
benchmarker = GPTBenchmark(
args.engine_dir,
args.model,
args.mode,
batch_size_options,
in_out_len_options,
args.dtype,
args.refit,
args.num_beams,
args.top_k,
args.top_p,
args.output_dir,
args.n_positions,
args.max_input_len,
args.max_output_len,
args.max_batch_size,
force_num_layer_1=args.force_num_layer_1,
enable_cuda_graph=args.enable_cuda_graph,
enable_custom_all_reduce=args.enable_custom_all_reduce,
strongly_typed=args.strongly_typed,
quantization=args.quantization)
benchmark_profiler = BenchmarkProfiler()
benchmarker = GPTBenchmark(args, batch_size_options, in_out_len_options)
elif args.model in get_allowed_models(benchmark_type="bert"):
benchmarker = BERTBenchmark(args.engine_dir,
args.model,
args.mode,
batch_size_options,
input_len_options,
args.dtype,
args.output_dir,
args.n_positions,
args.max_input_len,
args.max_output_len,
args.max_batch_size,
force_num_layer_1=args.force_num_layer_1)
benchmarker = BERTBenchmark(args, batch_size_options, input_len_options)
else:
raise Exception(f'Unexpected model: {args.model}')
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
benchmarker.print_report_header(args.csv)
benchmarker.print_report_header(args.csv,
benchmark_profiler=benchmark_profiler)
for config in benchmarker.get_config():
try:
inputs = benchmarker.prepare_inputs(config)
@ -290,12 +245,16 @@ def main(args):
for _ in range(args.warm_up):
benchmarker.run(inputs, config)
logger.info('Warm up done. Start benchmarking.')
if benchmark_profiler is not None:
benchmark_profiler.clean()
benchmark_profiler.start()
cur_duration = 0
start_time = time()
while iter_idx < args.num_runs or cur_duration < args.duration:
start.record()
benchmarker.run(inputs, config)
benchmarker.run(inputs,
config,
benchmark_profiler=benchmark_profiler)
end.record()
torch.cuda.synchronize()
@ -315,6 +274,9 @@ def main(args):
memory_monitor.stop()
_, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
peak_gpu_used = round(peak_gpu_used, 3)
if benchmark_profiler is not None:
benchmark_profiler.add_aux_info('iter_count', iter_idx)
benchmark_profiler.stop()
latency = round(sum(latencies) / iter_idx, 3)
latencies.sort()
@ -325,7 +287,8 @@ def main(args):
percentile95,
percentile99,
peak_gpu_used,
csv=args.csv)
csv=args.csv,
benchmark_profiler=benchmark_profiler)
if __name__ == '__main__':


@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class BenchmarkProfiler(object):
cuda_event_dict: dict
timer_dict: dict
aux_info: dict
started: bool
def __init__(self):
self.cuda_event_dict = {}
self.timer_dict = {}
self.aux_info = {}
self.started = False
def clean(self):
self.cuda_event_dict = {}
self.timer_dict = {}
self.aux_info = {}
def start(self):
self.started = True
def stop(self):
self.started = False
def get_cuda_event(self, name: str):
if name not in self.cuda_event_dict.keys():
event = torch.cuda.Event(enable_timing=True)
self.cuda_event_dict[name] = event
return self.cuda_event_dict[name]
def record_cuda_event(self, name: str):
if not self.started:
return
event = self.get_cuda_event(name)
event.record()
def get_timer_value(self, timer_name: str):
# timer is in milliseconds
return self.timer_dict[timer_name]
def record_elapsed_time(self, start_event_name: str, end_event_name: str,
timer_name: str):
if timer_name not in self.timer_dict.keys():
self.timer_dict[timer_name] = 0.0
if not self.started:
return
self.get_cuda_event(start_event_name).synchronize()
self.get_cuda_event(end_event_name).synchronize()
self.timer_dict[timer_name] += self.get_cuda_event(
start_event_name).elapsed_time(self.get_cuda_event(end_event_name))
def get_aux_info(self, aux_name):
return self.aux_info[aux_name]
def add_aux_info(self, aux_name: str, add_value):
if aux_name not in self.aux_info.keys():
self.aux_info[aux_name] = 0
if not self.started:
return
self.aux_info[aux_name] += add_value
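To make the intended call pattern of `BenchmarkProfiler` concrete, here is a minimal usage sketch (illustrative only; it requires a CUDA device, and the event/timer names and `run_one_iteration` placeholder are made up for the example):

```
# Hypothetical benchmarking loop instrumented with BenchmarkProfiler.
profiler = BenchmarkProfiler()
profiler.clean()
profiler.start()
for _ in range(10):
    profiler.record_cuda_event('iter_start')
    run_one_iteration()  # placeholder for the work being benchmarked
    profiler.record_cuda_event('iter_end')
    # Accumulates the elapsed milliseconds between the two recorded events.
    profiler.record_elapsed_time('iter_start', 'iter_end', 'iter_ms')
    profiler.add_aux_info('iter_count', 1)
profiler.stop()
print(profiler.get_timer_value('iter_ms'), profiler.get_aux_info('iter_count'))
```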


@ -14,86 +14,45 @@
# limitations under the License.
import os
import time
from collections import OrderedDict
# isort: off
import torch
import tensorrt as trt
#isort: on
from allowed_configs import get_build_config
from base_benchmark import BaseBenchmark, serialize_engine
from base_benchmark import BaseBenchmark
from build import build_bert
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch
from tensorrt_llm.builder import Builder
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.runtime import TensorInfo
class BERTBenchmark(BaseBenchmark):
def __init__(self,
engine_dir,
model_name,
mode,
batch_sizes,
in_lens,
dtype,
output_dir,
n_positions=None,
max_input_len=None,
max_output_len=None,
max_batch_size=None,
**kwargs):
super().__init__(engine_dir, model_name, dtype, output_dir)
def __init__(self, args, batch_sizes, in_lens):
super().__init__(args.engine_dir, args.model, args.dtype)
self.batch_sizes = batch_sizes
self.in_lens = in_lens
self.build_time = 0
self.mode = args.mode
if engine_dir is not None:
if args.engine_dir is not None:
# Deserialize engine from engine directory
self.serialize_path = os.path.join(engine_dir, self.engine_name)
self.serialize_path = os.path.join(args.engine_dir,
self.engine_name)
with open(self.serialize_path, 'rb') as f:
engine_buffer = f.read()
else:
# Build engine
self.use_bert_attention_plugin = False
self.use_gemm_plugin = False
self.use_layernorm_plugin = False
self.enable_qk_half_accum = False
self.enable_context_fmha = False
if mode == 'plugin':
self.use_bert_attention_plugin = dtype
self.use_gemm_plugin = dtype
self.use_layernorm_plugin = dtype
for key, value in get_build_config(model_name).items():
for key, value in get_build_config(args.model).items():
setattr(self, key, value)
# Override the n_positions/max_input_len/max_output_len/max_batch_size to value from cmd line if that's specified.
if n_positions is not None:
assert isinstance(
n_positions, int
) and n_positions > 0, f"n_positions should be a valid int number, got {n_positions}"
self.n_positions = n_positions
if max_input_len is not None:
assert isinstance(
max_input_len, int
) and max_input_len > 0, f"max_input_len should be a valid int number, got {max_input_len}"
self.max_input_len = max_input_len
if max_output_len is not None:
assert isinstance(
max_output_len, int
) and max_output_len > 0, f"max_output_len should be a valid int number, got {max_output_len}"
self.max_output_len = max_output_len
if max_batch_size is not None:
assert isinstance(
max_batch_size, int
) and max_batch_size > 0, f"max_batch_size should be a valid int number, got {max_batch_size}"
self.max_batch_size = max_batch_size
if kwargs.get('force_num_layer_1', False):
if args.force_num_layer_1:
self.num_layers = 1
engine_buffer = self.build()
start = time.time()
engine_buffer = build_bert(args)
self.build_time = round(time.time() - start, 2)
assert engine_buffer is not None
@ -128,99 +87,7 @@ class BERTBenchmark(BaseBenchmark):
stream = torch.cuda.current_stream().cuda_stream
return (inputs, outputs, stream)
def build(self):
bs_range = [1, (self.max_batch_size + 1) // 2, self.max_batch_size]
inlen_range = [1, (self.max_input_len + 1) // 2, self.max_input_len]
builder = Builder()
builder_config = builder.create_builder_config(
name=self.model_name,
precision=self.dtype,
timing_cache=None,
tensor_parallel=self.world_size, # TP only
parallel_build=True,
num_layers=self.num_layers,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
max_batch_size=self.max_batch_size,
max_input_len=self.max_input_len,
opt_level=self.builder_opt)
# Initialize model
tensorrt_llm_bert = tensorrt_llm.models.BertModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
type_vocab_size=self.type_vocab_size,
mapping=tensorrt_llm.Mapping(world_size=self.world_size,
tp_size=self.world_size))
# Module -> Network
network = builder.create_network()
if self.use_bert_attention_plugin:
network.plugin_config.set_bert_attention_plugin(
dtype=self.use_bert_attention_plugin)
if self.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=self.use_gemm_plugin)
if self.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=self.use_layernorm_plugin)
if self.enable_qk_half_accum:
network.plugin_config.enable_qk_half_accum()
if self.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if self.world_size > 1:
network.plugin_config.set_nccl_plugin(self.dtype)
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_bert.named_parameters())
# Forward
input_ids = tensorrt_llm.Tensor(
name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
input_lengths = tensorrt_llm.Tensor(name='input_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size', [bs_range])
]))
hidden_states = tensorrt_llm_bert(input_ids=input_ids,
input_lengths=input_lengths)
# Mark outputs
hidden_states_dtype = str_dtype_to_trt(self.dtype)
hidden_states.mark_output('hidden_states', hidden_states_dtype)
# Network -> Engine
start = time.time()
engine = builder.build_engine(network, builder_config)
end = time.time()
self.build_time = round(end - start, 2)
if self.output_dir is not None:
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
self.serialize_path = os.path.join(self.output_dir,
self.engine_name)
serialize_engine(engine, self.serialize_path)
if self.runtime_rank == 0:
config_path = os.path.join(self.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
def run(self, inputs, config):
def run(self, inputs, config, benchmark_profiler=None):
ok = self.session.run(*inputs)
assert ok, "Runtime execution failed"
torch.cuda.synchronize()
@ -235,8 +102,14 @@ class BERTBenchmark(BaseBenchmark):
f'percentile99(ms) {percentile99} latency(ms) {latency}')
print(line)
def report(self, config, latency, percentile95, percentile99, peak_gpu_used,
csv):
def report(self,
config,
latency,
percentile95,
percentile99,
peak_gpu_used,
csv,
benchmark_profiler=None):
report_dict = super().get_report_dict()
batch_size, inlen = config[0], config[1]
report_dict["num_heads"] = self.num_heads

benchmarks/python/build.py

@ -0,0 +1,606 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import multiprocessing as mp
import os
from collections import OrderedDict
import tensorrt as trt
import torch
from allowed_configs import (get_allowed_models, get_build_config,
get_model_family)
from base_benchmark import get_engine_name, serialize_engine
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
def parse_arguments():
parser = argparse.ArgumentParser(description='Build TensorRT-LLM models.')
parser.add_argument('-m',
'--model',
type=str,
required=True,
choices=get_allowed_models(),
help='Specify model you want to build.')
parser.add_argument(
'--mode',
type=str,
default="plugin",
choices=['ootb', 'plugin', 'ootb-except-mha'],
help=
('Choose mode between ootb/plugin/ootb-except-mha. '
'\"ootb\" means the engines will be built without any plugins, '
'\"plugin\" means the engines will be built with tuned recipe of using plugins.'
'\"ootb-except-mha\" means the engines will be built with only attention plugins.'
))
parser.add_argument(
'--dtype',
type=str,
default='float16',
choices=['float16', 'bfloat16', 'float32'],
help='Choose data type between float16/bfloat16/float32.')
parser.add_argument(
'--quantization',
type=str,
default=None,
choices=[
'fp8', 'fp8_gemm', 'fp8_kv_cache', 'int8_sq_per_tensor',
'int8_sq_per_token_channel', 'int8_weight_only', 'int4_weight_only',
'int4_weight_only_awq', 'int4_weight_only_gptq'
],
help="Optimize the model with specified quantization recipe")
parser.add_argument(
'--log_level',
type=str,
default="error",
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
help=
'Choose log level between verbose/info/warning/error/internal_error.')
parser.add_argument(
'--output_dir',
type=str,
required=True,
help='TensorRT engines will be saved to the specified path.')
parser.add_argument(
'--max_beam_width',
type=int,
default=None,
help=
('If this option is specified, it will override the max beam width of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_input_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max input len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_output_len',
type=int,
default=None,
help=
('If this option is specified, it will override the max output len of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument(
'--max_batch_size',
type=int,
default=None,
help=
('If this option is specified, it will override the max batch size of '
'TRT engines to the specified value instead of using pre-defined one'))
parser.add_argument('--force_num_layer_1',
default=False,
action='store_true',
help='Quick sanity check with num_layer=1.')
return parser.parse_args()
def get_quant_mode(quantization):
quant_mode = QuantMode(0)
strongly_typed = False
use_smooth_quant = False
per_token = False
per_channel = False
weight_only_precision = 'int8'
if quantization == "fp8":
strongly_typed = True
quant_mode = quant_mode.set_fp8_qdq()
quant_mode = quant_mode.set_fp8_kv_cache()
elif quantization == "fp8_gemm":
strongly_typed = True
quant_mode = quant_mode.set_fp8_qdq()
elif quantization == "fp8_kv_cache":
strongly_typed = True
quant_mode = quant_mode.set_fp8_kv_cache()
elif quantization == "int8_sq_per_tensor":
use_smooth_quant = True
quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
elif quantization == "int8_sq_per_token_channel":
use_smooth_quant = True
per_token = True
per_channel = True
quant_mode = QuantMode.use_smooth_quant(per_token, per_channel)
elif quantization == "int8_weight_only":
use_smooth_quant = False
weight_only_precision = 'int8'
quant_mode = QuantMode.use_weight_only(False)
elif quantization == "int4_weight_only":
weight_only_precision = 'int4'
quant_mode = QuantMode.use_weight_only(True)
elif quantization == "int4_weight_only_awq":
weight_only_precision = 'int4_awq'
quant_mode = QuantMode.from_description(quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)
elif quantization == "int4_weight_only_gptq":
weight_only_precision = 'int4_gptq'
quant_mode = QuantMode.from_description(quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)
elif quantization == None:
pass
else:
raise Exception(f'Unexpected quantization: {quantization}')
return quant_mode, strongly_typed, use_smooth_quant, weight_only_precision
def build_gpt(args):
build_config = get_build_config(args.model)
if args.force_num_layer_1:
build_config['num_layers'] = 1
# More parameters
world_size = tensorrt_llm.mpi_world_size()
runtime_rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(runtime_rank)
num_kv_heads = build_config['num_heads'] \
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
apply_query_key_layer_scaling = False
max_batch_size = build_config['max_batch_size'] \
if args.max_batch_size is None else args.max_batch_size
max_input_len = build_config['max_input_len'] \
if args.max_input_len is None else args.max_input_len
max_output_len = build_config['max_output_len'] \
if args.max_output_len is None else args.max_output_len
max_beam_width = build_config['max_beam_width'] \
if args.max_beam_width is None else args.max_beam_width
quant_mode, strongly_typed, use_smooth_quant, weight_only_precision = get_quant_mode(
args.quantization)
use_weight_only = quant_mode.is_weight_only()
builder = Builder()
builder_config = builder.create_builder_config(
name=args.model,
precision=args.dtype,
timing_cache=None,
tensor_parallel=world_size, # TP only
parallel_build=True,
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=num_kv_heads,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_output_len=max_output_len,
int8=(quant_mode.has_act_and_weight_quant()
or quant_mode.is_int8_weight_only()),
quant_mode=quant_mode,
use_refit=False,
opt_level=build_config['builder_opt'],
strongly_typed=strongly_typed)
engine_name = get_engine_name(args.model, args.dtype, world_size,
runtime_rank)
kv_dtype = str_dtype_to_trt(args.dtype)
# Initialize Module
family = get_model_family(args.model)
if family == "gpt":
tensorrt_llm_model = tensorrt_llm.models.GPTLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
position_embedding_type=PositionEmbeddingType.learned_absolute
if build_config['position_embedding_type'] is None else
build_config['position_embedding_type'],
rotary_embedding_percentage=build_config['rotary_pct'],
quant_mode=quant_mode,
bias=build_config['bias'],
moe_layer_config=tensorrt_llm.moe_config.MoeLayerConfig(
build_config["moe_num_experts"], build_config["moe_top_k"]))
elif family == "opt":
tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
pre_norm=build_config['pre_norm'],
do_layer_norm_before=build_config['do_layer_norm_before'])
elif family == "llama":
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=num_kv_heads,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
quant_mode=quant_mode,
use_fused_mlp=True)
elif family == "gptj":
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
rotary_dim=build_config['rotary_dim'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
quant_mode=quant_mode)
elif family == "gptneox":
tensorrt_llm_model = tensorrt_llm.models.GPTNeoXForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
rotary_dim=build_config['rotary_dim'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling)
elif family == "chatglm":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=quant_mode,
model_name="chatglm_6b")
elif family == "chatglm2":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=quant_mode,
model_name="chatglm2_6b")
elif family == "chatglm3":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=quant_mode,
model_name="chatglm3_6b")
elif family == "bloom":
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size), # TP only
quant_mode=quant_mode,
use_parallel_embedding=(args.model == 'bloom_176b'))
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=num_kv_heads,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
bias=build_config['bias'],
quant_mode=quant_mode,
use_alibi=build_config['use_alibi'],
new_decoder_architecture=build_config['new_decoder_architecture'],
parallel_attention=build_config['parallel_attention'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size))
else:
raise Exception(f'Unexpected model: {args.model}')
quant_kwargs = {}
if family == "llama" and use_weight_only:
if weight_only_precision == 'int4_awq':
quant_kwargs = {
"group_size": 128,
"zero": False,
"pre_quant_scale": True,
"exclude_modules": [],
}
elif weight_only_precision == 'int4_gptq':
quant_kwargs = {
"group_size": 128,
"zero": True,
"pre_quant_scale": False,
}
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
# Plugins
if args.mode == 'plugin':
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
if args.quantization is None or "fp8" not in args.quantization:
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
# Quantization plugins.
if use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_layernorm_quantization_plugin(
dtype=args.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
elif use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
elif family == "llama" and quant_mode.has_act_and_weight_quant():
# RMS norm plugin for SmoothQuant
network.plugin_config.set_rmsnorm_quantization_plugin(
dtype=args.dtype)
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if world_size > 1:
network.plugin_config.set_nccl_plugin(
dtype=args.dtype,
use_custom_all_reduce=build_config["use_custom_all_reduce"])
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_model.named_parameters())
# Forward
inputs = tensorrt_llm_model.prepare_inputs(max_batch_size,
max_input_len,
max_output_len, True,
max_beam_width)
tensorrt_llm_model(*inputs)
if args.mode == 'plugin':
tensorrt_llm.graph_rewriting.optimize(network)
# Network -> Engine
engine = builder.build_engine(network, builder_config)
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
serialize_path = os.path.join(args.output_dir, engine_name)
serialize_engine(engine, serialize_path)
if runtime_rank == 0:
config_path = os.path.join(args.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
def build_bert(args):
build_config = get_build_config(args.model)
if args.force_num_layer_1:
build_config['num_layers'] = 1
# More parameters
world_size = tensorrt_llm.mpi_world_size()
runtime_rank = tensorrt_llm.mpi_rank()
torch.cuda.set_device(runtime_rank)
num_kv_heads = build_config['num_heads'] \
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
max_batch_size = build_config['max_batch_size'] \
if args.max_batch_size is None else args.max_batch_size
max_input_len = build_config['max_input_len'] \
if args.max_input_len is None else args.max_input_len
bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
inlen_range = [1, (max_input_len + 1) // 2, max_input_len]
builder = Builder()
builder_config = builder.create_builder_config(
name=args.model,
precision=args.dtype,
timing_cache=None,
tensor_parallel=world_size, # TP only
parallel_build=True,
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=num_kv_heads,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
max_batch_size=max_batch_size,
max_input_len=max_input_len,
opt_level=build_config['builder_opt'])
engine_name = get_engine_name(args.model, args.dtype, world_size,
runtime_rank)
# Initialize model
tensorrt_llm_bert = tensorrt_llm.models.BertModel(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
type_vocab_size=build_config['type_vocab_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size, tp_size=world_size))
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
# Plugins
if args.mode == 'plugin':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
network.plugin_config.enable_qk_half_accum()
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
elif args.mode == 'ootb-except-mha':
network.plugin_config.set_bert_attention_plugin(dtype=args.dtype)
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if world_size > 1:
network.plugin_config.set_nccl_plugin(
dtype=args.dtype,
use_custom_all_reduce=build_config["use_custom_all_reduce"])
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_bert.named_parameters())
# Forward
input_ids = tensorrt_llm.Tensor(
name='input_ids',
dtype=trt.int32,
shape=[-1, -1],
dim_range=OrderedDict([('batch_size', [bs_range]),
('input_len', [inlen_range])]),
)
input_lengths = tensorrt_llm.Tensor(name='input_lengths',
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([
('batch_size', [bs_range])
]))
hidden_states = tensorrt_llm_bert(input_ids=input_ids,
input_lengths=input_lengths)
# Mark outputs
hidden_states_dtype = str_dtype_to_trt(args.dtype)
hidden_states.mark_output('hidden_states', hidden_states_dtype)
# Network -> Engine
engine = builder.build_engine(network, builder_config)
assert engine is not None, f'Failed to build engine for rank {runtime_rank}'
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
serialize_path = os.path.join(args.output_dir, engine_name)
serialize_engine(engine, serialize_path)
if runtime_rank == 0:
config_path = os.path.join(args.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
def main(args):
logger.set_level(args.log_level)
if args.model in get_allowed_models(benchmark_type="gpt"):
build_gpt(args)
elif args.model in get_allowed_models(benchmark_type="bert"):
build_bert(args)
else:
raise Exception(f'Unexpected model: {args.model}')
if __name__ == '__main__':
mp.set_start_method('spawn')
args = parse_arguments()
main(args)

View File

@ -17,158 +17,95 @@ import time
from math import ceil
import torch
from allowed_configs import get_build_config, get_model_family
from base_benchmark import BaseBenchmark, get_engine_name, serialize_engine
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.layers import PositionEmbeddingType
from tensorrt_llm.models import quantize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from allowed_configs import get_build_config # isort:skip
from base_benchmark import BaseBenchmark # isort:skip
from build import build_gpt, get_quant_mode # isort:skip
class GPTBenchmark(BaseBenchmark):
def __init__(self,
engine_dir,
model_name,
mode,
batch_sizes,
in_out_lens,
dtype,
refit,
num_beams,
top_k,
top_p,
output_dir,
n_positions=None,
max_input_len=None,
max_output_len=None,
max_batch_size=None,
enable_custom_all_reduce=None,
**kwargs):
super().__init__(engine_dir, model_name, dtype, output_dir)
def __init__(self, args, batch_sizes, in_out_lens):
super().__init__(args.engine_dir, args.model, args.dtype)
self.batch_sizes = batch_sizes
self.in_out_lens = in_out_lens
self.refit = refit
self.num_beams = num_beams
self.num_beams = args.num_beams
self.mode = args.mode
self.build_time = 0
self.mode = mode # plugin or ootb or ootb-except-mha
self.fuse_bias = True
self.cuda_graph_mode = kwargs.get('enable_cuda_graph', False)
self.strongly_typed = kwargs.get('strongly_typed', False)
self.enable_custom_all_reduce = enable_custom_all_reduce
self.cuda_graph_mode = args.enable_cuda_graph
if engine_dir is not None:
if args.engine_dir is not None:
# Getting build configs from the engine directory is done in the base class
# Deserialize engine from engine directory
self.serialize_path = os.path.join(engine_dir, self.engine_name)
self.serialize_path = os.path.join(args.engine_dir,
self.engine_name)
with open(self.serialize_path, 'rb') as f:
engine_buffer = f.read()
else:
# Build engine
self.world_size = tensorrt_llm.mpi_world_size()
self.apply_query_key_layer_scaling = False
self.use_weight_only = False
self.per_group = False
self.weight_only_precision = 'int8'
self.per_token = False
self.per_channel = False
use_mha_plugin = mode == 'plugin' or mode == 'ootb-except-mha'
mha_plg_dtype = dtype if use_mha_plugin else False
use_non_mha_plugin = mode == 'plugin'
non_mha_plg_dtype = dtype if use_non_mha_plugin else False
self.use_gpt_attention_plugin = mha_plg_dtype
self.use_gemm_plugin = non_mha_plg_dtype
# Starting with TRT 9.1, the OOTB norm layer outperforms the plugin norm layer
self.use_layernorm_plugin = False
self.use_rmsnorm_plugin = False
self.use_lookup_plugin = non_mha_plg_dtype
self.use_weight_only_quant_gemm_plugin = non_mha_plg_dtype
self.enable_context_fmha = use_mha_plugin
self.remove_input_padding = use_non_mha_plugin
for key, value in get_build_config(model_name).items():
for key, value in get_build_config(args.model).items():
setattr(self, key, value)
if self.quantization is None:
self.quantization = kwargs.get('quantization', None)
self.set_quantization()
# Override n_positions/max_input_len/max_output_len/max_batch_size with the command-line values if specified.
if n_positions is not None:
assert isinstance(
n_positions, int
) and n_positions > 0, f"n_positions should be a positive integer, got {n_positions}"
self.n_positions = n_positions
if max_input_len is not None:
assert isinstance(
max_input_len, int
) and max_input_len > 0, f"max_input_len should be a positive integer, got {max_input_len}"
self.max_input_len = max_input_len
if max_output_len is not None:
assert isinstance(
max_output_len, int
) and max_output_len > 0, f"max_output_len should be a positive integer, got {max_output_len}"
self.max_output_len = max_output_len
if max_batch_size is not None:
assert isinstance(
max_batch_size, int
) and max_batch_size > 0, f"max_batch_size should be a positive integer, got {max_batch_size}"
self.max_batch_size = max_batch_size
if self.num_kv_heads is None:
self.num_kv_heads = self.num_heads
if kwargs.get('force_num_layer_1', False):
if args.force_num_layer_1:
self.num_layers = 1
engine_buffer = self.build()
self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
self.enable_fp8 = self.quant_mode.has_fp8_qdq()
self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
# Plugins
self.use_gpt_attention_plugin = False
self.remove_input_padding = False
if args.mode == 'plugin':
self.use_gpt_attention_plugin = True
self.remove_input_padding = True
elif args.mode == 'ootb-except-mha':
self.use_gpt_attention_plugin = True
start = time.time()
engine_buffer = build_gpt(args)
self.build_time = round(time.time() - start, 2)
assert engine_buffer is not None
if not hasattr(self, 'num_kv_heads') or self.num_kv_heads is None:
self.num_kv_heads = self.num_heads
model_config = tensorrt_llm.runtime.ModelConfig(
vocab_size=self.vocab_size,
num_layers=self.num_layers,
num_heads=self.num_heads // self.world_size,
num_kv_heads=ceil(self.num_kv_heads / self.world_size),
hidden_size=self.hidden_size // self.world_size,
vocab_size=self.vocab_size,
num_layers=self.num_layers,
gpt_attention_plugin=self.use_gpt_attention_plugin,
remove_input_padding=self.remove_input_padding,
quant_mode=self.quant_mode,
use_custom_all_reduce=self.enable_custom_all_reduce,
use_custom_all_reduce=self.use_custom_all_reduce,
)
if model_name == 'chatglm_6b':
if args.model == 'chatglm_6b':
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=130005,
pad_id=3,
num_beams=num_beams,
top_k=top_k,
top_p=top_p)
num_beams=self.num_beams,
top_k=args.top_k,
top_p=args.top_p)
self.decoder = tensorrt_llm.runtime.ChatGLMGenerationSession(
model_config, engine_buffer, self.runtime_mapping)
elif model_name in ['chatglm2_6b', 'chatglm3_6b']:
elif args.model in ['chatglm2_6b', 'chatglm3_6b']:
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=2,
pad_id=0,
num_beams=num_beams,
top_k=top_k,
top_p=top_p)
num_beams=self.num_beams,
top_k=args.top_k,
top_p=args.top_p)
self.decoder = tensorrt_llm.runtime.GenerationSession(
model_config, engine_buffer, self.runtime_mapping)
else:
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=50256,
pad_id=50256,
num_beams=num_beams,
top_k=top_k,
top_p=top_p)
num_beams=self.num_beams,
top_k=args.top_k,
top_p=args.top_p)
self.decoder = tensorrt_llm.runtime.GenerationSession(
model_config,
engine_buffer,
@ -201,370 +138,37 @@ class GPTBenchmark(BaseBenchmark):
self.decoder.setup(batch_size, inlen, outlen, beam_width=self.num_beams)
return (input_ids, input_lengths)
def set_quantization(self):
self.quant_mode = QuantMode(0)
def get_report_dict(self, benchmark_profiler=None):
report_dict = super().get_report_dict(
benchmark_profiler=benchmark_profiler)
if benchmark_profiler is not None:
report_dict["generation_time(ms)"] = None
report_dict["total_generated_tokens"] = None
report_dict["generation_tokens_per_second"] = None
return report_dict
if self.quantization == "fp8":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_qdq()
self.quant_mode = self.quant_mode.set_fp8_kv_cache()
elif self.quantization == "fp8_gemm":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_qdq()
elif self.quantization == "fp8_kv_cache":
self.strongly_typed = True
self.quant_mode = self.quant_mode.set_fp8_kv_cache()
elif self.quantization == "int8_sq_per_tensor":
self.use_smooth_quant = True
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)
elif self.quantization == "int8_sq_per_token_channel":
self.use_smooth_quant = True
self.per_token = True
self.per_channel = True
self.quant_mode = QuantMode.use_smooth_quant(
self.per_token, self.per_channel)
elif self.quantization == "int8_weight_only":
self.use_smooth_quant = False
self.use_weight_only = True
self.weight_only_precision = 'int8'
self.quant_mode = QuantMode.use_weight_only(False)
elif self.quantization == "int4_weight_only":
self.use_weight_only = True
self.weight_only_precision = 'int4'
self.quant_mode = QuantMode.use_weight_only(True)
elif self.quantization == "int4_weight_only_awq":
self.use_weight_only = True
self.per_group = True
self.weight_only_precision = 'int4_awq'
self.quant_mode = QuantMode.from_description(
quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)
elif self.quantization == "int4_weight_only_gptq":
self.use_weight_only = True
self.per_group = True
self.weight_only_precision = 'int4_gptq'
self.quant_mode = QuantMode.from_description(
quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True)
elif self.quantization is None:
pass
else:
raise Exception(f'Invalid quantization config: {self.quantization}')
def build(self):
builder = Builder()
builder_config = builder.create_builder_config(
name=self.model_name,
precision=self.dtype,
timing_cache=None,
tensor_parallel=self.world_size, # TP only
parallel_build=True,
num_layers=self.num_layers,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
apply_query_key_layer_scaling=self.apply_query_key_layer_scaling,
max_batch_size=self.max_batch_size,
max_input_len=self.max_input_len,
max_output_len=self.max_output_len,
int8=self.quant_mode.has_act_and_weight_quant()
or self.quant_mode.is_int8_weight_only(),
quant_mode=self.quant_mode,
use_refit=self.refit,
opt_level=self.builder_opt,
strongly_typed=self.strongly_typed)
engine_name = get_engine_name(self.model_name, self.dtype,
self.world_size, self.runtime_rank)
kv_dtype = str_dtype_to_trt(self.dtype)
# Initialize Module
family = get_model_family(self.model_name)
if family == "gpt":
tensorrt_llm_model = tensorrt_llm.models.GPTLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
position_embedding_type=PositionEmbeddingType.learned_absolute
if self.position_embedding_type is None else
self.position_embedding_type,
rotary_embedding_percentage=self.rotary_pct,
quant_mode=self.quant_mode,
bias=self.bias)
elif family == "opt":
tensorrt_llm_model = tensorrt_llm.models.OPTLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
pre_norm=self.pre_norm,
do_layer_norm_before=self.do_layer_norm_before)
elif family == "llama":
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
num_layers=self.num_layers,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mlp_hidden_size=self.inter_size,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
quant_mode=self.quant_mode)
elif family == "gptj":
tensorrt_llm_model = tensorrt_llm.models.GPTJForCausalLM(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
rotary_dim=self.rotary_dim,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
quant_mode=self.quant_mode)
elif family == "gptneox":
tensorrt_llm_model = tensorrt_llm.models.GPTNeoXForCausalLM(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
rotary_dim=self.rotary_dim,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling)
elif family == "chatglm":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode,
model_name="chatglm_6b")
elif family == "chatglm2":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode,
model_name="chatglm2_6b")
elif family == "chatglm3":
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
apply_query_key_layer_scaling=builder_config.
apply_query_key_layer_scaling,
quant_mode=self.quant_mode,
model_name="chatglm3_6b")
elif family == "bloom":
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
num_layers=self.num_layers,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
mapping=tensorrt_llm.Mapping(
world_size=self.world_size,
tp_size=self.world_size), # TP only
quant_mode=self.quant_mode,
use_parallel_embedding=(self.model_name == 'bloom_176b'))
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
num_layers=self.num_layers,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
hidden_size=self.hidden_size,
vocab_size=self.vocab_size,
max_position_embeddings=self.n_positions,
dtype=kv_dtype,
bias=self.bias,
quant_mode=self.quant_mode,
use_alibi=self.use_alibi,
new_decoder_architecture=self.new_decoder_architecture,
parallel_attention=self.parallel_attention,
mapping=tensorrt_llm.Mapping(world_size=self.world_size,
tp_size=self.world_size))
else:
raise Exception(f'Unexpected model: {self.model_name}')
quant_kwargs = {}
if family == "llama" and self.use_weight_only:
if self.weight_only_precision == 'int4_awq':
quant_kwargs = {
"group_size": 128,
"zero": False,
"pre_quant_scale": True,
"exclude_modules": [],
}
elif self.weight_only_precision == 'int4_gptq':
quant_kwargs = {
"group_size": 128,
"zero": True,
"pre_quant_scale": False,
}
tensorrt_llm_model = quantize_model(tensorrt_llm_model, self.quant_mode,
**quant_kwargs)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
not_fp8_quantization = self.quantization is None or "fp8" not in self.quantization
if self.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=self.use_gpt_attention_plugin)
if self.use_gemm_plugin and not_fp8_quantization:
network.plugin_config.set_gemm_plugin(dtype=self.use_gemm_plugin)
if self.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=self.use_layernorm_plugin)
if self.use_rmsnorm_plugin:
network.plugin_config.set_rmsnorm_plugin(
dtype=self.use_rmsnorm_plugin)
if self.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if self.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
# Quantization plugins.
if self.use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=self.dtype)
network.plugin_config.set_layernorm_quantization_plugin(
dtype=self.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
elif self.use_weight_only and self.use_weight_only_quant_gemm_plugin:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=self.dtype)
# RMS norm plugin for SmoothQuant
if self.quant_mode.has_act_and_weight_quant(
) and 'llama' in self.model_name:
network.plugin_config.set_rmsnorm_quantization_plugin()
if self.world_size > 1:
network.plugin_config.set_nccl_plugin(self.dtype,
self.enable_custom_all_reduce)
# Use the lookup plugin for embedding parallelism and sharing
network.plugin_config.set_lookup_plugin(dtype=self.use_lookup_plugin)
with net_guard(network):
# Prepare
network.set_named_parameters(tensorrt_llm_model.named_parameters())
# Forward
inputs = tensorrt_llm_model.prepare_inputs(self.max_batch_size,
self.max_input_len,
self.max_output_len,
True, self.num_beams)
tensorrt_llm_model(*inputs)
if self.fuse_bias:
tensorrt_llm.graph_rewriting.optimize(network)
# Network -> Engine
start = time.time()
engine = builder.build_engine(network, builder_config)
end = time.time()
self.build_time = round(end - start, 2)
if self.output_dir is not None:
os.makedirs(self.output_dir, exist_ok=True)
self.serialize_path = os.path.join(self.output_dir,
self.engine_name)
serialize_engine(engine, self.serialize_path)
if self.runtime_rank == 0:
config_path = os.path.join(self.output_dir, 'config.json')
builder_config.plugin_config = network.plugin_config
builder.save_config(builder_config, config_path)
return engine
def run(self, inputs, config):
def run(self, inputs, config, benchmark_profiler=None):
batch_size, inlen, outlen = config[0], config[1], config[2]
self.decoder.setup(batch_size, inlen, outlen, beam_width=self.num_beams)
if self.remove_input_padding:
self.decoder.decode_batch(inputs[0], self.sampling_config)
self.decoder.decode_batch(inputs[0],
self.sampling_config,
benchmark_profiler=benchmark_profiler)
else:
self.decoder.decode(inputs[0], inputs[1], self.sampling_config)
self.decoder.decode(inputs[0],
inputs[1],
self.sampling_config,
benchmark_profiler=benchmark_profiler)
torch.cuda.synchronize()
def report(self, config, latency, percentile95, percentile99, peak_gpu_used,
csv):
def report(self,
config,
latency,
percentile95,
percentile99,
peak_gpu_used,
csv,
benchmark_profiler=None):
report_dict = super().get_report_dict()
batch_size, inlen, outlen = config[0], config[1], config[2]
tokens_per_sec = round(batch_size * outlen / (latency / 1000), 2)
@ -582,6 +186,20 @@ class GPTBenchmark(BaseBenchmark):
report_dict["percentile95(ms)"] = percentile95
report_dict["percentile99(ms)"] = percentile99
report_dict["gpu_peak_mem(gb)"] = peak_gpu_used
if benchmark_profiler is not None:
iter_count = benchmark_profiler.get_aux_info('iter_count')
generation_time_ms = benchmark_profiler.get_timer_value(
'generation_time')
generation_step_count = benchmark_profiler.get_aux_info(
'generation_step_count')
token_per_step = batch_size * self.num_beams
total_tokens = generation_step_count * token_per_step
report_dict["generation_time(ms)"] = generation_time_ms / iter_count
report_dict["total_generated_tokens"] = total_tokens / iter_count
tokens_per_second = round(
total_tokens * 1000.0 / generation_time_ms, 3)
report_dict["generation_tokens_per_second"] = tokens_per_second
if self.runtime_rank == 0:
if csv:
line = ",".join([str(v) for v in report_dict.values()])

View File

@ -112,31 +112,6 @@ private:
std::atomic<bool> shutdown_requested_;
void decoupled_execution_loop();
std::shared_ptr<std::thread> worker_thread_;
inline static const std::string kInputIdsTensorName_ = "input_ids";
inline static const std::string kDraftInputIdsTensorName_ = "draft_input_ids";
inline static const std::string kMaxNewTokensTensorName_ = "request_output_len";
inline static const std::string kBeamWidthTensorName_ = "beam_width";
inline static const std::string kEndIdTensorName_ = "end_id";
inline static const std::string kPadIdTensorName_ = "pad_id";
inline static const std::string kBadWordsListTensorName_ = "bad_words_list";
inline static const std::string kStopWordsListTensorName_ = "stop_words_list";
inline static const std::string kEmbeddingBiasTensorName_ = "embedding_bias";
inline static const std::string kTemperatureTensorName_ = "temperature";
inline static const std::string kRuntimeTopKTensorName_ = "runtime_top_k";
inline static const std::string kRuntimeTopPTensorName_ = "runtime_top_p";
inline static const std::string kLengthPenaltyTensorName_ = "len_penalty";
inline static const std::string kRepetitionPenaltyTensorName_ = "repetition_penalty";
inline static const std::string kMinLengthTensorName_ = "min_length";
inline static const std::string kPresencePenaltyTensorName_ = "presence_penalty";
inline static const std::string kRandomSeedTensorName_ = "random_seed";
inline static const std::string kReturnLogProbsTensorName_ = "return_log_probs";
inline static const std::string kPromptEmbeddingTableName_ = "prompt_embedding_table";
inline static const std::string kPromptVocabSizeName_ = "prompt_vocab_size";
inline static const std::string kOutputIdsTensorName_ = "output_ids";
inline static const std::string kSequenceLengthTensorName_ = "sequence_length";
inline static const std::string kLogProbsTensorName_ = "output_log_probs";
inline static const std::string kCumLogProbsTensorName_ = "cum_log_probs";
std::shared_ptr<nvinfer1::ILogger> mLogger{};
};

View File

@ -16,113 +16,82 @@
#pragma once
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <tuple>
#include <vector>
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace tensorrt_llm::batch_manager
{
template <typename TTensor, typename TTensorMap>
namespace inference_request
{
// Input tensors
auto constexpr kInputIdsTensorName = "input_ids";
auto constexpr kDraftInputIdsTensorName = "draft_input_ids";
auto constexpr kDraftLogitsTensorName = "draft_logits";
auto constexpr kMaxNewTokensTensorName = "request_output_len";
auto constexpr kBeamWidthTensorName = "beam_width";
auto constexpr kEndIdTensorName = "end_id";
auto constexpr kPadIdTensorName = "pad_id";
auto constexpr kBadWordsListTensorName = "bad_words_list";
auto constexpr kStopWordsListTensorName = "stop_words_list";
auto constexpr kEmbeddingBiasTensorName = "embedding_bias";
auto constexpr kTemperatureTensorName = "temperature";
auto constexpr kRuntimeTopKTensorName = "runtime_top_k";
auto constexpr kRuntimeTopPTensorName = "runtime_top_p";
auto constexpr kLengthPenaltyTensorName = "len_penalty";
auto constexpr kRepetitionPenaltyTensorName = "repetition_penalty";
auto constexpr kMinLengthTensorName = "min_length";
auto constexpr kPresencePenaltyTensorName = "presence_penalty";
auto constexpr kRandomSeedTensorName = "random_seed";
auto constexpr kReturnLogProbsTensorName = "return_log_probs";
auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
// Output tensors
auto constexpr kOutputIdsTensorName = "output_ids";
auto constexpr kSequenceLengthTensorName = "sequence_length";
auto constexpr kLogProbsTensorName = "output_log_probs";
auto constexpr kCumLogProbsTensorName = "cum_log_probs";
} // namespace inference_request
template <typename TTensor, typename TNamedTensor>
class GenericInferenceRequest
{
public:
using TensorPtr = TTensor;
using TensorMap = TTensorMap;
using NamedTensorType = TNamedTensor;
using TensorMap = std::unordered_map<std::string, TTensor>;
GenericInferenceRequest(uint64_t requestId)
: mRequestId(requestId)
, mIsStreaming(false)
explicit GenericInferenceRequest(uint64_t requestId)
: mRequestId{requestId}
, mIsStreaming{false}
{
}
GenericInferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: mInputTensors(inputTensors)
, mRequestId(requestId)
, mIsStreaming(false)
GenericInferenceRequest(uint64_t requestId, TensorMap&& tensorMap)
: mRequestId{requestId}
, mIsStreaming{false}
, mInputTensors{std::move(tensorMap)}
{
}
GenericInferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: mInputTensors(std::move(inputTensors))
, mRequestId(requestId)
, mIsStreaming(false)
{
}
~GenericInferenceRequest() {}
template <typename T>
std::tuple<bool, T> getScalarValueFromTensor(
const std::string& inputTensorName, const std::vector<int64_t>& expectedShape, const bool is_optional) const
{
T scalarValue;
try
for (auto const& [name, tensor] : mInputTensors)
{
const auto& t = getInputTensor(inputTensorName);
std::vector<int64_t> tensorShape(t->getShape().nbDims);
for (int32_t i = 0; i < t->getShape().nbDims; ++i)
{
tensorShape[i] = t->getShape().d[i];
}
if (tensorShape != expectedShape)
{
std::string err = "Invalid shape for " + inputTensorName + ". Expected shape: [";
for (auto shape : expectedShape)
{
err += std::to_string(shape) + ",";
}
if (!expectedShape.empty())
{
// Remove last comma
err.pop_back();
}
err += "]";
throw std::runtime_error(err);
}
scalarValue = *static_cast<T*>(t->data());
validateTensorName(name);
}
catch (const std::exception& e)
{
// If parameter is optional, just ignore it
if (is_optional)
{
return {false, scalarValue};
}
else
{
std::cerr << "Out of Range error for tensor: " << inputTensorName << ": " << e.what() << '\n';
throw;
}
}
return {true, scalarValue};
}
const TensorPtr& getInputTensor(std::string const& inputTensorName) const
GenericInferenceRequest(uint64_t requestId, TensorMap const& tensorMap)
: GenericInferenceRequest(requestId, TensorMap{tensorMap})
{
return mInputTensors.at(inputTensorName);
}
void emplaceInputTensor(std::string const& inputTensorName, TensorPtr&& inputTensor)
{
mInputTensors.emplace(inputTensorName, std::move(inputTensor));
}
void setIsStreaming(bool isStreaming)
@ -130,91 +99,152 @@ public:
mIsStreaming = isStreaming;
}
bool isStreaming() const
[[nodiscard]] bool isStreaming() const
{
return mIsStreaming;
}
uint64_t getRequestId() const
[[nodiscard]] uint64_t getRequestId() const
{
return mRequestId;
}
static std::array constexpr kTensorNames = {
inference_request::kInputIdsTensorName,
inference_request::kDraftInputIdsTensorName,
inference_request::kDraftLogitsTensorName,
inference_request::kMaxNewTokensTensorName,
inference_request::kBeamWidthTensorName,
inference_request::kEndIdTensorName,
inference_request::kPadIdTensorName,
inference_request::kBadWordsListTensorName,
inference_request::kStopWordsListTensorName,
inference_request::kEmbeddingBiasTensorName,
inference_request::kTemperatureTensorName,
inference_request::kRuntimeTopKTensorName,
inference_request::kRuntimeTopPTensorName,
inference_request::kLengthPenaltyTensorName,
inference_request::kRepetitionPenaltyTensorName,
inference_request::kMinLengthTensorName,
inference_request::kPresencePenaltyTensorName,
inference_request::kRandomSeedTensorName,
inference_request::kReturnLogProbsTensorName,
inference_request::kPromptEmbeddingTableName,
inference_request::kPromptVocabSizeName,
};
#define TENSOR_GETTER_SETTER(funcName, tensorName) \
\
[[nodiscard]] bool has##funcName() const \
{ \
return mInputTensors.find(tensorName) != mInputTensors.end(); \
} \
\
[[nodiscard]] TensorPtr const& get##funcName() const \
{ \
auto it = mInputTensors.find(tensorName); \
TLLM_CHECK_WITH_INFO(it != mInputTensors.end(), "Undefined tensor: %s", tensorName); \
return it->second; \
} \
\
[[nodiscard]] TensorPtr get##funcName##Unchecked() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? it->second : TensorPtr{}; \
} \
\
[[nodiscard]] NamedTensorType get##funcName##Named() const \
{ \
auto it = mInputTensors.find(tensorName); \
return it != mInputTensors.end() ? NamedTensorType{it->second, tensorName} : NamedTensor{tensorName}; \
} \
\
void set##funcName(TensorPtr const& tensor) \
{ \
mInputTensors[tensorName] = tensor; \
}
TENSOR_GETTER_SETTER(InputIds, inference_request::kInputIdsTensorName)
TENSOR_GETTER_SETTER(DraftInputIds, inference_request::kDraftInputIdsTensorName)
TENSOR_GETTER_SETTER(DraftLogits, inference_request::kDraftLogitsTensorName)
TENSOR_GETTER_SETTER(MaxNewTokens, inference_request::kMaxNewTokensTensorName)
TENSOR_GETTER_SETTER(BeamWidth, inference_request::kBeamWidthTensorName)
TENSOR_GETTER_SETTER(EndId, inference_request::kEndIdTensorName)
TENSOR_GETTER_SETTER(PadId, inference_request::kPadIdTensorName)
TENSOR_GETTER_SETTER(BadWordsList, inference_request::kBadWordsListTensorName)
TENSOR_GETTER_SETTER(StopWordsList, inference_request::kStopWordsListTensorName)
TENSOR_GETTER_SETTER(EmbeddingBias, inference_request::kEmbeddingBiasTensorName)
TENSOR_GETTER_SETTER(Temperature, inference_request::kTemperatureTensorName)
TENSOR_GETTER_SETTER(RuntimeTopK, inference_request::kRuntimeTopKTensorName)
TENSOR_GETTER_SETTER(RuntimeTopP, inference_request::kRuntimeTopPTensorName)
TENSOR_GETTER_SETTER(LengthPenalty, inference_request::kLengthPenaltyTensorName)
TENSOR_GETTER_SETTER(RepetitionPenalty, inference_request::kRepetitionPenaltyTensorName)
TENSOR_GETTER_SETTER(MinLength, inference_request::kMinLengthTensorName)
TENSOR_GETTER_SETTER(PresencePenalty, inference_request::kPresencePenaltyTensorName)
TENSOR_GETTER_SETTER(RandomSeed, inference_request::kRandomSeedTensorName)
TENSOR_GETTER_SETTER(ReturnLogProbs, inference_request::kReturnLogProbsTensorName)
TENSOR_GETTER_SETTER(PromptEmbeddingTable, inference_request::kPromptEmbeddingTableName)
TENSOR_GETTER_SETTER(PromptVocabSize, inference_request::kPromptVocabSizeName)
#undef TENSOR_GETTER_SETTER
protected:
TensorMap mInputTensors;
static void validateTensorName(std::string const& tensorName)
{
// TODO (martinma): Throw an exception if the tensor name is not valid.
if (std::find(kTensorNames.begin(), kTensorNames.end(), tensorName) == kTensorNames.end())
{
TLLM_LOG_WARNING("Invalid tensor name in InferenceRequest: %s", tensorName.c_str());
}
}
uint64_t mRequestId;
bool mIsStreaming;
TensorMap mInputTensors;
};
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>
{
public:
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>;
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr, NamedTensor>;
using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap;
InferenceRequest(uint64_t requestId)
explicit InferenceRequest(uint64_t requestId)
: Base(requestId)
{
}
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
: Base(requestId, inputTensors)
{
}
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
: Base(requestId, std::move(inputTensors))
{
}
const std::vector<int64_t> serialize() const
[[deprecated("Use direct tensor access instead")]] [[nodiscard]] TensorPtr const& getInputTensor(
std::string const& inputTensorName) const
{
std::list<int64_t> packed;
// mInputTensors
packed.push_back(static_cast<int64_t>(mInputTensors.size()));
for (auto it = mInputTensors.begin(); it != mInputTensors.end(); ++it)
{
NamedTensor nt(it->second, it->first);
auto packed_tensor = nt.serialize();
packed.push_back(static_cast<int64_t>(packed_tensor.size()));
packed.insert(packed.end(), packed_tensor.begin(), packed_tensor.end());
}
// mRequestId
packed.push_back(static_cast<int64_t>(mRequestId));
// mIsStreaming
packed.push_back(mIsStreaming ? 1 : 0);
// done
std::vector<int64_t> vpacked{
std::make_move_iterator(std::begin(packed)), std::make_move_iterator(std::end(packed))};
return vpacked;
auto it = Base::mInputTensors.find(inputTensorName);
TLLM_CHECK_WITH_INFO(it != Base::mInputTensors.end(), "Invalid input tensor name: %s", inputTensorName.c_str());
return it->second;
}
static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed)
[[deprecated("Use direct tensor access instead")]] void emplaceInputTensor(
std::string const& inputTensorName, TensorPtr inputTensor)
{
return InferenceRequest::deserialize(packed.data());
validateTensorName(inputTensorName);
Base::mInputTensors[inputTensorName] = std::move(inputTensor);
}
static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr)
{
int64_t num_tensors = *packed_ptr++;
TensorMap InputTensors;
for (int64_t i = 0; i < num_tensors; ++i)
{
int64_t n = *packed_ptr++;
auto inputTensor = NamedTensor::deserialize(packed_ptr);
packed_ptr += n;
auto inputTensorName = inputTensor.name;
InputTensors.emplace(inputTensorName, std::move(inputTensor.tensor));
}
uint64_t RequestId = static_cast<uint64_t>(*packed_ptr++);
bool IsStreaming = *packed_ptr++ != 0;
std::shared_ptr<InferenceRequest> ir = std::make_shared<InferenceRequest>(InputTensors, RequestId);
ir->setIsStreaming(IsStreaming);
return ir;
}
[[nodiscard]] std::vector<int64_t> serialize() const;
static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed);
static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr);
};
} // namespace tensorrt_llm::batch_manager
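To see the refactored request surface end to end, here is a minimal, non-authoritative sketch of building a request with the macro-generated accessors and round-tripping it through the packed int64 representation; the include path and the pre-built `inputIds` tensor are assumptions, not part of this change:
#include "tensorrt_llm/batch_manager/inferenceRequest.h" // assumed include path

using tensorrt_llm::batch_manager::InferenceRequest;
using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;

// Hypothetical helper: packs a request into the int64 wire format and restores it.
std::shared_ptr<InferenceRequest> roundTrip(uint64_t requestId, TensorPtr inputIds)
{
    InferenceRequest request{requestId};
    request.setInputIds(inputIds);  // setter generated by TENSOR_GETTER_SETTER
    request.setIsStreaming(true);
    auto const packed = request.serialize();      // declared above, defined out of line
    return InferenceRequest::deserialize(packed); // reverses the packing
}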

View File

@ -55,7 +55,8 @@ public:
std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@ -75,6 +76,7 @@ public:
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
, mDraftLogits(draftLogits)
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to other beam
@ -89,6 +91,13 @@ public:
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
}
if (draftLogits.has_value() && !draftTokens.has_value())
{
std::string errStr = "Draft tokens must be specified when draft logits are given.";
TLLM_LOG_ERROR(errStr);
throw std::runtime_error(errStr);
}
}
/// @brief Get total number of tokens for this req (prompt + generated)
@ -135,6 +144,13 @@ public:
return mDraftTokens;
}
/// @brief Get the logits for the draft tokens
/// @return Tensor of draft logits
std::optional<TensorPtr> getDraftLogits() const
{
return mDraftLogits;
}
/// @brief Returns true if request has draft tokens
/// @return flag
bool hasDraftTokens() const
@ -162,7 +178,7 @@ public:
/// beamTokens is expected to be of size beamWidth
void addNewTokens(const std::vector<TokenIdType>& beamTokens)
{
assert(mSamplingConfig.beamWidth == beamTokens.size());
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
const auto outputId = beamTokens[beam];
@ -174,7 +190,7 @@ public:
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(const BeamTokens& generatedBeamTokens)
{
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
{
auto& beamTokens = mTokens[beam];
@ -305,6 +321,11 @@ public:
mDraftTokens = draftTokens;
}
void setDraftLogits(const std::optional<TensorPtr>& draftLogits)
{
mDraftLogits = draftLogits;
}
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
@ -333,6 +354,7 @@ protected:
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
std::shared_ptr<VecTokens> mDraftTokens;
std::optional<TensorPtr> mDraftLogits;
};
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
@ -353,9 +375,11 @@ public:
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens)
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens,
draftLogits)
{
}
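As a small, hedged sketch of how the new draft-logits plumbing might be driven from caller code (the function and variable names are illustrative, not part of the patch):
#include "tensorrt_llm/batch_manager/llmRequest.h" // assumed include path

namespace tbm = tensorrt_llm::batch_manager;

// Attach externally computed draft logits to a request. The constructor already
// rejects draft logits without draft tokens, so mirror that check here.
void attachDraftLogits(tbm::LlmRequest& request, tensorrt_llm::runtime::ITensor::SharedPtr draftLogits)
{
    if (request.hasDraftTokens())
    {
        request.setDraftLogits(draftLogits);
    }
}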

View File

@ -18,11 +18,14 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include <string>
namespace tensorrt_llm::batch_manager
{
template <typename TTensor>
struct GenericNamedTensor
class GenericNamedTensor
{
public:
using TensorPtr = TTensor;
TensorPtr tensor;
@ -31,24 +34,32 @@ struct GenericNamedTensor
GenericNamedTensor() = default;
~GenericNamedTensor() = default;
// Host Tensor constructor
GenericNamedTensor(
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
GenericNamedTensor(TensorPtr _tensor, std::string _name)
: tensor(std::move(_tensor))
, name(std::move(_name))
: tensor{std::move(_tensor)}
, name{std::move(_name)}
{
}
GenericNamedTensor(std::string _name)
: name(std::move(_name))
explicit GenericNamedTensor(std::string _name)
: tensor{}
, name{std::move(_name)}
{
}
TensorPtr operator()()
{
return tensor;
}
TensorPtr const& operator()() const
{
return tensor;
}
};
struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
class NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
{
public:
using Base = GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;
@ -56,9 +67,13 @@ struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::S
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
NamedTensor(TensorPtr _tensor, std::string _name)
: Base(_tensor, _name){};
: Base(std::move(_tensor), std::move(_name)){};
explicit NamedTensor(std::string _name)
: Base(std::move(_name)){};
[[nodiscard]] std::vector<int64_t> serialize() const;
std::vector<int64_t> serialize();
static NamedTensor deserialize(const int64_t* packed);
};
} // namespace tensorrt_llm::batch_manager

View File

@ -33,16 +33,19 @@ public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true)
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
bool useContextFMHAForGeneration = false)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, useContextFMHAForGeneration(useContextFMHAForGeneration)
{
}
KvCacheConfig kvCacheConfig;
std::optional<SizeType> maxNumSequences;
bool enableTrtOverlap;
bool useContextFMHAForGeneration;
};
} // namespace tensorrt_llm::batch_manager
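For reference, a minimal sketch of toggling the new flag through the public member, with all other fields left at their defaults (the include path is an assumption):
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" // assumed include path

tensorrt_llm::batch_manager::TrtGptModelOptionalParams makeParams()
{
    // All other fields keep their defaults (enableTrtOverlap = true, etc.).
    tensorrt_llm::batch_manager::TrtGptModelOptionalParams params{};
    params.useContextFMHAForGeneration = true; // new in this change
    return params;
}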

View File

@ -71,9 +71,7 @@ public:
// Vector of views on newTokensSteps for each token. Elements are on gpu.
// optional parameters
TensorPtr finishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states at each generated token of
// maxTokensPerStep, on gpu
TensorPtr finished; // [batchSize, beamWidth], usually a view of finishedSteps for current token.
TensorPtr finished; // [batchSize, beamWidth],
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
// true. In beam search and to determine whether to stop according to
// DecodingInput.sequenceLimitLength, on gpu

View File

@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <curand_kernel.h>
#include <cstdint>
#include <memory>
@ -55,9 +56,17 @@ public:
DecodingInput const& decodingInput, BufferManager const& manager)
= 0;
static void acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds, const ITensor& contextLengths,
const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec, ITensor& finishedFinal,
ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream);
virtual const SamplingConfig& getSamplingConfig() = 0;
static void acceptDraftTokensByIds(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths,
const ITensor& finishedVec, ITensor& finishedFinal, ITensor& finishedSum,
BufferManager::CudaStreamPtr const& stream);
static void acceptDraftTokensByLogits(ITensor& draftLogits, const ITensor& targetLogits, ITensor& draftProbs,
ITensor& targetProbs, const ITensor& numDraftTokens, ITensor& finished, SizeType vocabSize,
SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(
nvinfer1::DataType dtype, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
@ -82,6 +91,11 @@ public:
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
BufferManager const& manager) override;
const SamplingConfig& getSamplingConfig() override
{
return mSamplingConfig;
}
private:
BufferManager mManager;
@ -90,6 +104,7 @@ private:
TensorPtr mLogProbsTiled; // Buffer used to store the transpose of the logProbs. Needed because the kernels have
// been written to use that shape.
SamplingConfig mSamplingConfig;
};
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(

View File

@ -181,7 +181,10 @@ private:
DecodingOutputPtr mJointDecodingOutput;
std::vector<TensorPtr> mDraftTokenIds;
std::vector<TensorPtr> mDraftLogits;
std::vector<bool> mAcceptByLogits;
TensorPtr mNumDraftTokens;
TensorPtr mCurandStates;
std::vector<SizeType> mNbSteps;
std::vector<bool> mFinished;
@ -189,6 +192,13 @@ private:
std::vector<SizeType> mMaxNewTokens;
std::vector<SizeType> mBeamWidths;
std::vector<SizeType> mGeneratedTokensPerStep;
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
// for each generated token of maxTokensPerStep, on gpu
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
SizeType mMaxSequenceLength{};
SizeType mMaxKvCacheLength{};
SizeType mActualBatchSize{};

View File

@ -33,6 +33,8 @@
#include <typeinfo>
#include <vector>
#include "tensorrt_llm/common/dataType.h"
namespace tensorrt_llm::runtime
{
@ -213,19 +215,7 @@ public:
[[nodiscard]] constexpr std::size_t getSize() const noexcept
{
switch (static_cast<nvinfer1::DataType>(*this))
{
case nvinfer1::DataType::kINT64: return 8;
case nvinfer1::DataType::kINT32: [[fallthrough]];
case nvinfer1::DataType::kFLOAT: return 4;
case nvinfer1::DataType::kBF16: [[fallthrough]];
case nvinfer1::DataType::kHALF: return 2;
case nvinfer1::DataType::kBOOL: [[fallthrough]];
case nvinfer1::DataType::kUINT8: [[fallthrough]];
case nvinfer1::DataType::kINT8: [[fallthrough]];
case nvinfer1::DataType::kFP8: return 1;
}
return 0;
return tensorrt_llm::common::getDTypeSize(static_cast<nvinfer1::DataType>(*this));
}
private:

View File

@ -65,6 +65,8 @@ public:
std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType> endId; // end token id
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
std::optional<TensorPtr>
draftLogits; // [generatedTokensPerStep - 1, vocabSize], on gpu, draft logits from speculative decoding
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu

View File

@ -22,6 +22,7 @@
#include <NvInferRuntime.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <initializer_list>
@ -238,6 +239,39 @@ public:
//!
static std::string toString(Shape const& dims);
//!
//! \brief A convenience function to compare shapes.
//!
static bool shapeEquals(Shape const& lhs, Shape const& rhs)
{
return shapeEquals(lhs, rhs.d, rhs.nbDims);
}
//!
//! \brief A convenience function to compare shapes.
//!
template <typename T>
static bool shapeEquals(Shape const& lhs, T const* dims, SizeType count)
{
return lhs.nbDims == count && std::equal(lhs.d, lhs.d + lhs.nbDims, dims);
}
bool shapeEquals(Shape const& other) const
{
return shapeEquals(getShape(), other);
}
bool shapeEquals(std::initializer_list<SizeType> const& other) const
{
return shapeEquals(getShape(), other.begin(), other.size());
}
template <typename T>
bool shapeEquals(T const* dims, SizeType count) const
{
return shapeEquals(getShape(), dims, count);
}
protected:
ITensor() = default;
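A minimal usage sketch of the new shape-comparison helpers (the function and parameter names are illustrative):
#include "tensorrt_llm/runtime/iTensor.h"

using tensorrt_llm::runtime::ITensor;
using tensorrt_llm::runtime::SizeType;

// Check that a logits tensor matches the expected [batchSize, beamWidth, vocabSize] layout.
bool hasLogitsShape(ITensor const& logits, SizeType batchSize, SizeType beamWidth, SizeType vocabSize)
{
    return logits.shapeEquals({batchSize, beamWidth, vocabSize});
}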

View File

@ -56,6 +56,9 @@ public:
// beam search layer
OptVec<FloatType> beamSearchDiversityRate;
OptVec<FloatType> lengthPenalty;
// speculative decoding
OptVec<FloatType> draftAcceptanceThreshold;
};
} // namespace tensorrt_llm::runtime

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e6a5d7dba399049a4da9ca729153e5a6080986782a314b867e7635454eb36de
size 1705954
oid sha256:ba982afff27c597c9f5f25bec4ed37debd883c7be2107b47776a014075899fbd
size 1719266

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64fae7bca97be7c3067b4544da0c3d79621ec3632c10e39b7a005d886702e8eb
size 1706098
oid sha256:04ec1f2f45dde1ef6b6b0f605e79715eebed38b19b4d833fcb668d2cb71f8a03
size 1733118

View File

@ -1,3 +1,3 @@
02375d908e57e2194e3f28a4e83dd963 libtensorrt_llm_batch_manager_static.a
e5c4994ecc347d808f6d38fb686b5cf1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
cd83045d7c127af5f907efd7c710bd6fe1f90ec4 commit
aab384dfc59de5df4c7ecf53e30d03e9 libtensorrt_llm_batch_manager_static.a
e0074afa6959c896f1cbc7ab90872058 libtensorrt_llm_batch_manager_static.pre_cxx11.a
c7450aa071e91659a3e2855c0cca21021f96ada8 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3cca913fc62df4119e4df10921be97086714740148f54c528da7bb2826f67ba
size 1617426
oid sha256:546c9e2b79cb3cf2623876902ef2d40c65925157d43850b2505eedf274e060a1
size 1638840

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d633874e8b32a56758bf8bbdc0955ed8c5d43d531ba330a2274bcec13e1c89f
size 1620144
oid sha256:935a706ce0d107f8c226566a50946a0f0e35ce926c98b7a12b000b3d72e5f0b6
size 1635602

View File

@ -1,2 +1,2 @@
f379e62b3f69afa4bd1d8e5551a6ede4 libtensorrt_llm_batch_manager_static.a
92f44b2834d39c2c62a9b0bd0549b159 libtensorrt_llm_batch_manager_static.pre_cxx11.a
8e0c5b31d579f4118b84a34ffb00c15a libtensorrt_llm_batch_manager_static.a
885ea7b9f594d7aa9cc9018527b95f6d libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -45,11 +45,13 @@ extern bool CHECK_DEBUG_ENABLED;
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} while (0)
#define TLLM_CHECK_WITH_INFO(val, info) \
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} while (0)
#define TLLM_CHECK_DEBUG(val) \

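With the macro now routed through `tensorrt_llm::common::fmtstr`, the info argument may be a printf-style format followed by arguments. A minimal sketch, assuming the macro's header path; `validateBeamWidth` is a hypothetical helper:

```cpp
#include "tensorrt_llm/common/assert.h" // header path assumed

// Hypothetical helper for illustration only.
void validateBeamWidth(int beamWidth, int maxBeamWidth)
{
    // A fixed message still works (no extra arguments are formatted).
    TLLM_CHECK_WITH_INFO(beamWidth > 0, "beamWidth must be positive");

    // New in this change: printf-style arguments, expanded through fmtstr(info, ##__VA_ARGS__).
    TLLM_CHECK_WITH_INFO(
        beamWidth <= maxBeamWidth, "beamWidth (%d) exceeds maxBeamWidth (%d)", beamWidth, maxBeamWidth);
}
```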
View File

@ -57,6 +57,7 @@ CUDADriverWrapper::CUDADriverWrapper()
*(void**) (&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*(void**) (&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*(void**) (&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*(void**) (&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*(void**) (&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
@ -109,6 +110,11 @@ CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod
return (*_cuModuleGetFunction)(hfunc, hmod, name);
}
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const
{
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const
{

View File

@ -54,6 +54,8 @@ public:
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const;
@ -78,6 +80,7 @@ private:
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, const void*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
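
The new `cuModuleGetGlobal` wrapper lets callers resolve device globals inside a loaded module through the same lazily-loaded driver handle. A hedged sketch; the header path, the namespace, and the symbol name "kernel_metadata" are assumptions:

```cpp
#include "tensorrt_llm/common/cudaDriverWrapper.h" // header path assumed
#include <cuda.h>
#include <cstdio>

// Namespace and symbol name are assumptions for illustration.
void printGlobalSymbolSize(tensorrt_llm::common::CUDADriverWrapper const& driver, CUmodule module)
{
    CUdeviceptr symbolPtr{};
    size_t symbolBytes{0};
    // Resolves a __device__/__constant__ symbol in a module loaded via cuModuleLoadData,
    // e.g. constant metadata embedded in a cubin.
    CUresult const status = driver.cuModuleGetGlobal(&symbolPtr, &symbolBytes, module, "kernel_metadata");
    std::printf("cuModuleGetGlobal -> status %d, %zu bytes\n", static_cast<int>(status), symbolBytes);
}
```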

View File

@ -287,6 +287,16 @@ inline int getMultiProcessorCount()
return multi_processor_count;
}
inline int getMaxSharedMemoryPerBlockOptin()
{
int device_id;
int max_shared_memory_per_block;
check_cuda_error(cudaGetDevice(&device_id));
check_cuda_error(
cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
return max_shared_memory_per_block;
}
inline int divUp(int a, int n)
{
return (a + n - 1) / n;

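The new `getMaxSharedMemoryPerBlockOptin` helper queries `cudaDevAttrMaxSharedMemoryPerBlockOptin`, which is higher than the default per-block shared memory limit. A small host-side sketch, assuming the helper lives in `tensorrt_llm::common` like the surrounding `cudaUtils.h` functions:

```cpp
#include "tensorrt_llm/common/cudaUtils.h" // header path assumed
#include <cstdio>

// Hypothetical helper: report whether a requested dynamic shared memory size fits the opt-in limit.
void reportSmemBudget(int requestedSmemBytes)
{
    // Kernels must still opt in via cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize)
    // to actually use more than the default per-block limit.
    int const maxOptin = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
    std::printf("requested %d bytes, opt-in limit %d bytes -> %s\n", requestedSmemBytes, maxOptin,
        requestedSmemBytes <= maxOptin ? "fits" : "too large");
}
```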
View File

@ -0,0 +1,41 @@
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <NvInferRuntime.h>
namespace tensorrt_llm::common
{
constexpr static size_t getDTypeSize(nvinfer1::DataType type)
{
switch (type)
{
case nvinfer1::DataType::kINT64: return 8;
case nvinfer1::DataType::kINT32: [[fallthrough]];
case nvinfer1::DataType::kFLOAT: return 4;
case nvinfer1::DataType::kBF16: [[fallthrough]];
case nvinfer1::DataType::kHALF: return 2;
case nvinfer1::DataType::kBOOL: [[fallthrough]];
case nvinfer1::DataType::kUINT8: [[fallthrough]];
case nvinfer1::DataType::kINT8: [[fallthrough]];
case nvinfer1::DataType::kFP8: return 1;
}
return 0;
}
} // namespace tensorrt_llm::common
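
Because `getDTypeSize` is `constexpr`, the size mapping that `BufferDataType::getSize()` now delegates to can be verified at compile time. A minimal sketch derived directly from the switch above (the include path matches the `iBuffer.h` hunk earlier in this diff):

```cpp
#include "tensorrt_llm/common/dataType.h"
#include <NvInferRuntime.h>

// Compile-time checks mirroring the switch above.
static_assert(tensorrt_llm::common::getDTypeSize(nvinfer1::DataType::kINT64) == 8, "int64 is 8 bytes");
static_assert(tensorrt_llm::common::getDTypeSize(nvinfer1::DataType::kFLOAT) == 4, "float is 4 bytes");
static_assert(tensorrt_llm::common::getDTypeSize(nvinfer1::DataType::kHALF) == 2, "half is 2 bytes");
static_assert(tensorrt_llm::common::getDTypeSize(nvinfer1::DataType::kFP8) == 1, "fp8 is 1 byte");
```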

View File

@ -38,7 +38,7 @@ std::string vformat(char const* fmt, va_list args)
std::string stringBuf(size, char{});
auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args);
TLLM_CHECK_WITH_INFO(size2 == size, std::strerror(errno));
TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno)));
return stringBuf;
}

View File

@ -42,6 +42,16 @@ static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& str
return stream;
}
inline std::string fmtstr(std::string const& s)
{
return s;
}
inline std::string fmtstr(std::string&& s)
{
return s;
}
#if defined(_MSC_VER)
std::string fmtstr(char const* format, ...);
#else

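The `std::string` passthrough overloads exist, presumably, so the reworked `TLLM_CHECK_WITH_INFO` can forward an already-built string (as in the `std::strerror` change above) without re-interpreting it as a printf format. A short hedged sketch; the header path is assumed:

```cpp
#include "tensorrt_llm/common/stringUtils.h" // header path assumed
#include <string>

// Hypothetical helper for illustration.
std::string describeError(int code)
{
    // An already-formatted std::string is returned unchanged, even when it contains '%'.
    std::string const msg = "progress 50% at error code " + std::to_string(code);
    return tensorrt_llm::common::fmtstr(msg);
}
```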
View File

@ -37,6 +37,7 @@
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"

View File

@ -48,11 +48,23 @@ struct EpilogueOpBiasFtGelu
{
};
struct EpilogueOpDefaultSilu
{
};
struct EpilogueOpDefaultReLU
{
};
struct EpilogueOpDefaultFtGelu
{
};
struct EpilogueOpBias
{
};
struct EpilogueOpNoBias
struct EpilogueOpDefault
{
};
@ -61,40 +73,66 @@ struct Epilogue
{
};
constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling, cutlass::FloatRoundStyle::round_to_nearest, true>;
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
ElementAccumulator, BiasScaleMode>;
};
constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, cutlass::epilogue::thread::ScaleType::Default>;
ElementAccumulator, DefaultScaleMode>;
};
} // namespace cutlass_extensions
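
The refactor separates the bias epilogues (`BiasScaleMode`, i.e. `NoBetaScaling`) from the new default ones (`DefaultScaleMode`). A rough sketch of the difference, per my reading of CUTLASS's `ScaleType` semantics (not stated in this diff):

```latex
% Hedged reading of the two scale modes used above:
\[
  \text{NoBetaScaling (EpilogueOpBias\ldots):} \quad D = \mathrm{act}\bigl(\alpha \cdot \mathrm{accum} + C\bigr)
\]
\[
  \text{Default (EpilogueOpDefault\ldots):} \quad D = \mathrm{act}\bigl(\alpha \cdot \mathrm{accum} + \beta \cdot C\bigr)
\]
```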

View File

@ -0,0 +1,73 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Scheduler for grouped GEMM
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
bool Transposed = false>
struct GemmMoeProblemVisitor
: public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>, ThreadblockShape,
GroupScheduleMode_, PrefetchTileCount, ThreadCount>
{
static bool const kTransposed = Transposed;
using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
using Base
= MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
using Params = typename Base::Params;
using SharedStorage = typename Base::SharedStorage;
//
// Methods
//
CUTLASS_DEVICE
GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, shared_storage_, block_idx)
{
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,522 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
// This section exists so that we can use the same kernel code for regular gemm and dequantizing gemms.
// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
template <typename...>
using void_t = void;
template <typename Mma, typename = void>
struct use_dq_gemm : platform::false_type
{
};
template <typename Mma>
struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type
{
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
/// arch.
GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform
>
struct MoeFCGemm
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = false;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
static_assert(!kTransposed, "Transpose problem not supported");
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementScale = ElementC;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmMoeProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
int problem_count;
int threadblock_count;
int group_size;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
int64_t* total_rows_before_expert;
int64_t gemm_n;
int64_t gemm_k;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, total_rows_before_expert(nullptr)
, gemm_n(0)
, gemm_k(0)
, host_problem_sizes(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C,
ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k,
GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count)
, threadblock_count(threadblock_count)
, group_size(group_size)
, output_op(output_op)
, ptr_A(const_cast<ElementA*>(ptr_A))
, ptr_B(const_cast<ElementB*>(ptr_B))
, weight_scales(const_cast<ElementScale*>(weight_scales))
, ptr_C(const_cast<ElementC*>(ptr_C))
, ptr_D(ptr_D)
, total_rows_before_expert(total_rows_before_expert)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, host_problem_sizes(nullptr)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
assert(weight_scales);
}
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int group_size;
typename EpilogueOutputOp::Params output_op;
ElementA* ptr_A;
ElementB* ptr_B;
ElementScale* weight_scales;
ElementC* ptr_C;
ElementC* ptr_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr)
, ptr_B(nullptr)
, weight_scales(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(
args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count)
, threadblock_count(args.threadblock_count)
, group_size(args.group_size)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, weight_scales(args.weight_scales)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
{
}
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor = typename ProblemVisitor::Params(
args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
weight_scales = args.weight_scales;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
}
};
/// Shared memory storage structure
union SharedStorage
{
typename ProblemVisitor::SharedStorage problem_visitor;
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
public:
//
// Methods
//
CUTLASS_DEVICE
MoeFCGemm() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
if (platform::is_same<uint8_t, ElementB>::value || platform::is_same<uint4b_t, ElementB>::value)
{
if (args.weight_scales == nullptr)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
return Status::kInvalid;
}
}
else if (args.weight_scales != nullptr)
{
CUTLASS_TRACE_HOST(
"MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
return Status::kInvalid;
}
else if (args.group_size != args.gemm_k)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)");
return Status::kInvalid;
}
// Handle the case where the input is too short
else if (args.gemm_n < Mma::IteratorB::AccessType::kElements)
{
CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment");
return Status::kInvalid;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
{
return 0;
}
// The dummy template parameter is not used and exists so that we can compile this code using
// a standard earlier than C++17. Prior to C++17, fully specialized templates had to exist in
// a namespace
template <bool B, typename dummy = void>
struct KernelRunner
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
{
CUTLASS_NOT_IMPLEMENTED();
}
};
template <typename dummy>
struct KernelRunner<true, dummy>
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
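
A host-side sketch of the per-expert pointer arithmetic inside `MoeFCGemm::KernelRunner::run_kernel` above, using toy sizes; it mirrors `rows_to_jump`, `bytes_per_expert_matrix`, and the A/B/C/D offsets but is not the kernel itself:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Toy MoE layout: 3 experts, gemm_k = 1024, gemm_n = 4096, 4-bit weights for ElementB.
    int64_t const gemm_k = 1024, gemm_n = 4096;
    int const bits_per_elt_B = 4; // e.g. cutlass::sizeof_bits<uint4b_t>::value
    int64_t const last_row_for_problem[3] = {8, 8, 20}; // running row counts; expert 1 receives no rows

    int64_t const bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * bits_per_elt_B;
    for (int problem_idx = 0; problem_idx < 3; ++problem_idx)
    {
        int64_t const rows_to_jump = problem_idx == 0 ? 0 : last_row_for_problem[problem_idx - 1];
        int64_t const gemm_m = last_row_for_problem[problem_idx] - rows_to_jump;
        // A advances by rows already consumed, B by whole expert weight matrices, C by one bias row
        // per expert, and D by rows of output.
        std::printf("expert %d: gemm_m=%lld A_off=%lld elts B_off=%lld bytes D_off=%lld elts\n", problem_idx,
            (long long) gemm_m, (long long) (rows_to_jump * gemm_k),
            (long long) (problem_idx * bytes_per_expert_matrix), (long long) (rows_to_jump * gemm_n));
    }
    return 0;
}
```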

View File

@ -0,0 +1,344 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file
\brief Base scheduler for grouped problems, using MoE
*/
#pragma once
#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor
{
using ThreadblockShape = ThreadblockShape_;
struct ProblemInfo
{
static int32_t const kNoPrefetchEntry = -1;
int32_t problem_idx;
int32_t problem_start;
CUTLASS_DEVICE
ProblemInfo()
: problem_idx(kNoPrefetchEntry)
, problem_start(kNoPrefetchEntry)
{
}
CUTLASS_DEVICE
ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
: problem_idx(problem_idx_)
, problem_start(problem_start_)
{
}
};
struct Params
{
int64_t const* last_row_for_problem;
int64_t gemm_n;
int64_t gemm_k;
int32_t problem_count;
void const* workspace;
int32_t tile_count;
//
// Methods
//
/// Ctor
CUTLASS_HOST_DEVICE
Params()
: last_row_for_problem(nullptr)
, gemm_n(0)
, gemm_k(0)
, problem_count(0)
, workspace(nullptr)
, tile_count(0)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count,
void const* workspace = nullptr, int32_t tile_count = 0)
: last_row_for_problem(last_row_for_problem)
, gemm_n(gemm_n)
, gemm_k(gemm_k)
, problem_count(problem_count)
, workspace(workspace)
, tile_count(tile_count)
{
}
};
Params const& params;
int32_t tile_idx;
int32_t problem_tile_start;
int32_t problem_idx;
//
// Methods
//
CUTLASS_DEVICE
BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
: params(params_)
, tile_idx(block_idx)
, problem_tile_start(0)
, problem_idx(0)
{
}
/// Get the grid shape
CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem)
{
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1);
}
/// Gets the global tile index
CUTLASS_HOST_DEVICE
int32_t tile_index() const
{
return tile_idx;
}
/// Gets the index of the problem
CUTLASS_HOST_DEVICE
int32_t problem_index() const
{
return problem_idx;
}
CUTLASS_HOST_DEVICE
int32_t threadblock_idx() const
{
return tile_idx - problem_tile_start;
}
CUTLASS_DEVICE
void advance(int32_t grid_size)
{
tile_idx += grid_size;
}
CUTLASS_HOST_DEVICE
static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem)
{
ProblemSizeHelper::possibly_transpose_problem(problem);
}
/// Returns the problem size for the current problem
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size() const
{
return problem_size(problem_idx);
}
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord problem_size(int idx) const
{
const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
const int64_t current_problem_row = params.last_row_for_problem[idx];
const int64_t gemm_m = current_problem_row - prev_problem_row;
GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k));
ProblemSizeHelper::possibly_transpose_problem(problem);
return problem;
}
CUTLASS_HOST_DEVICE
static int32_t tile_count(const cutlass::gemm::GemmCoord& grid)
{
return ProblemSizeHelper::tile_count(grid);
}
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count)
{
int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
auto problem = host_problem_sizes_ptr[i];
possibly_transpose_problem(problem);
auto grid = grid_shape(problem);
total_tiles += tile_count(grid);
}
return total_tiles;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_,
int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor;
/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>
{
using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
using Params = typename Base::Params;
static int const kThreadCount = ThreadCount;
static bool const kRequiresPrecomputation = false;
static int const kThreadsPerWarp = 32;
struct SharedStorage
{
};
// Final tile of the problem loaded by this thread. Each thread will hold
// a separate value.
int32_t problem_ending_tile;
SharedStorage& shared_storage;
//
// Methods
//
CUTLASS_DEVICE
MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
: Base(params_, block_idx)
, problem_ending_tile(0)
, shared_storage(shared_storage_)
{
this->problem_idx = -1 * kThreadsPerWarp;
this->problem_tile_start = 0;
}
CUTLASS_DEVICE
bool next_tile()
{
// Check whether the tile to compute is within the range of the current problem.
int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
if (this->tile_idx < problem_tile_end)
{
return true;
}
// Check whether the tile to compute is within the current group of problems fetched by the warp.
// The last tile for this group is the final tile of the problem held by the final thread in the warp.
int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
// Keep the starting problem for this group in `problem_idx`. This is done to reduce
// register pressure. The starting problem for this group is simply the first problem
// in the group most recently fetched by the warp.
int32_t& group_problem_start = this->problem_idx;
group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
// Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
// register pressure.
int32_t& group_tile_start = this->problem_tile_start;
// Each thread in the warp processes a separate problem to advance until
// reaching a problem whose starting tile is less than tile_idx.
while (group_tile_end <= this->tile_idx)
{
group_problem_start += kThreadsPerWarp;
if (group_problem_start > this->params.problem_count)
{
return false;
}
// Since `group_tile_start` is a reference to `this->problem_tile_start`, this
// also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
// is also set here is used later in `next_tile`.
group_tile_start = group_tile_end;
int lane_idx = threadIdx.x % kThreadsPerWarp;
int32_t lane_problem = group_problem_start + lane_idx;
// Compute the number of tiles in the problem assigned to each thread.
problem_ending_tile = 0;
if (lane_problem < this->params.problem_count)
{
cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
problem_ending_tile = this->tile_count(grid);
}
// Compute a warp-wide inclusive prefix sum to compute the ending tile index of
// each thread's problem.
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kThreadsPerWarp; i <<= 1)
{
int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
if (lane_idx >= i)
{
problem_ending_tile += val;
}
}
// The total tile count for this group is now in the final position of the prefix sum
int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
problem_ending_tile += group_tile_start;
group_tile_end += tiles_in_group;
}
// The next problem to process is the first one that does not have ending tile position
// that is greater than or equal to tile index.
int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
this->problem_idx = group_problem_start + problem_idx_in_group;
// The starting tile for this problem is the ending tile of the previous problem. In cases
// where `problem_idx_in_group` is the first problem in the group, we do not need to reset
// `problem_tile_start`, because it is set to the previous group's ending tile in the while
// loop above.
if (problem_idx_in_group > 0)
{
this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
}
return true;
}
static size_t get_workspace_size(
const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
{
return 0;
}
static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
int32_t block_count, void* host_workspace_ptr)
{
}
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
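
A host-side sketch of the tile bookkeeping the MoE problem visitor performs, using the same ceil-division as `grid_shape` above; the threadblock tile and problem sizes are toy values, and the per-problem tile count assumes the standard (non-transposed) size helper:

```cpp
#include <cstdio>

int main()
{
    int const kM = 128, kN = 128;      // hypothetical ThreadblockShape
    int const gemm_n = 4096;
    int const gemm_m[3] = {8, 0, 300}; // rows routed to each of three experts

    int total_tiles = 0;
    for (int i = 0; i < 3; ++i)
    {
        int const grid_m = (gemm_m[i] - 1 + kM) / kM; // same ceil-division as grid_shape(); 0 rows -> 0 tiles
        int const grid_n = (gemm_n - 1 + kN) / kN;
        total_tiles += grid_m * grid_n;
        std::printf("problem %d: %d x %d tiles\n", i, grid_m, grid_n);
    }
    std::printf("total tiles scheduled: %d\n", total_tiles); // 32 + 0 + 96 = 128
    return 0;
}
```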

View File

@ -25,7 +25,7 @@ namespace kernels
{
template <typename T>
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bool* finished_buf,
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int* parent_ids_buf, int batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
int vocab_size_padded, size_t step)
{
@ -60,7 +60,7 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bo
}
// if the beam has already finished, skip ngram check
if ((finished_buf != nullptr) && (finished_buf[id_offset + local_batch_idx * beam_width + beam_idx]))
if ((finished_buf != nullptr) && (finished_buf[id_offset + local_batch_idx * beam_width + beam_idx].isFinished()))
{
return;
}
@ -134,9 +134,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const bo
}
template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, const int* parent_ids_buf,
int batch_size, int local_batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
int vocab_size_padded, size_t step, cudaStream_t stream)
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width,
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream)
{
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
@ -160,7 +160,7 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* fin
}
#define INVOKE_BAN_REPEAT_NGRAM(T) \
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, \
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, \
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width, \
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream);

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/kernels/decodingCommon.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
@ -25,9 +26,9 @@ namespace kernels
{
template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const bool* finished_buf, const int* parent_ids_buf,
int batch_size, int local_batch_size, int beam_width, const int* no_repeat_ngram_size_buf, int id_offset,
int vocab_size_padded, size_t step, cudaStream_t stream);
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int* parent_ids_buf, int batch_size, int local_batch_size, int beam_width,
const int* no_repeat_ngram_size_buf, int id_offset, int vocab_size_padded, size_t step, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -756,8 +756,8 @@ template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_s
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
#endif
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
const int batch_size, const int beam_width)
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished,
const float* cum_log_probs, const int batch_size, const int beam_width)
{
const int bid = blockIdx.x;
const int tgt_start_idx = beam_hyps.num_beams[bid];
@ -802,7 +802,7 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finis
// TODO huggingface uses total length to normalize the scores, instead of number of generated tokens.
// Check whether that is reasonable or not.
beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty(cum_log_probs[src_beam_idx],
finished[src_beam_idx] ? last_token_idx + 1 : last_token_idx, length_penalty);
finished[src_beam_idx].isFinished() ? last_token_idx + 1 : last_token_idx, length_penalty);
beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx];
beam_hyps.num_beams[bid]++;
@ -810,7 +810,7 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finis
}
}
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
const int batch_size, const int beam_width, cudaStream_t stream)
{
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#include <cuda_runtime.h>
#pragma once
#include "tensorrt_llm/kernels/decodingCommon.h"
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace kernels
@ -76,7 +77,7 @@ void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequen
const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
const size_t d_model, cudaStream_t stream);
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const bool* finished, const float* cum_log_probs,
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
const int batch_size, const int beam_width, cudaStream_t stream);
void invokeCopyBatchMajorToGeneralPtr(

View File

@ -82,13 +82,16 @@ public:
, mHeadSize(headSize)
, mQScaling(qScaling)
, sm(sm_)
, xmmaKernel(getXMMAKernelsV2(data_type, sm_))
{
TLLM_CHECK_WITH_INFO(
(sm == kSM_80 || sm == kSM_86 || sm == kSM_89 || sm == kSM_90), "Unsupported architecture");
TLLM_CHECK_WITH_INFO((mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16), "Unsupported data type");
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
xmmaKernel = getXMMAKernelsV2(mDataType, sm);
params.clear();
pagedKVParams.clear();
// get device attributes
int device_id;
@ -99,6 +102,7 @@ public:
~mhaImpl() {}
// Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
@ -185,6 +189,95 @@ public:
}
}
// Support paged_kv_cache and chunked_attention.
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
Data_type scale_type = launch_params.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
set_alpha(pagedKVParams.scale_bmm1, scale_bmm1, scale_type);
set_alpha(pagedKVParams.scale_softmax, scale_softmax, scale_type);
set_alpha(pagedKVParams.scale_bmm2, scale_bmm2, scale_type);
pagedKVParams.b = b;
pagedKVParams.h = mNumHeads;
pagedKVParams.s = s_q;
pagedKVParams.d = mHeadSize;
// Total sequence length needed by TMA descriptor
// it should be the actual total seq length if non-padded input is given.
mTotalSeqLen = total_seqlen;
TLLM_CHECK_WITH_INFO(tokens_per_kv_block >= 128, "FMHA with paged kv cache needs tokens_per_block >= 128 !");
// Needed by TMA descriptors.
launch_params.blocks_per_context_sequence = blocks_per_context_sequence;
pagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
pagedKVParams.kv_stride_in_bytes = tokens_per_kv_block * mHeadSize * sizeof(half);
pagedKVParams.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80);
// always use flash attention kernels.
launch_params.flash_attention = true;
// flash attention kernels use s = 0 (support any seq length)
launch_params.kernel_s = 0;
launch_params.kernel_kv_s = s_kv;
launch_params.force_unroll = true;
// enable warp-specialization kernels when s > 512.
if (isSm90 && s_kv > 512)
{
launch_params.warp_specialization = true;
launch_params.use_tma = true;
}
else
{
// enable tiled kernels on Ampere/Ada
if (launch_params.flash_attention && s_kv <= 64)
{
// flash attention tiled kernels allow a larger free dim tile size (M, N) with flexibility
// in the unroll dimension tile size (K). For short sequence lengths (s <= 128), tiled kernels
// can suffer from tile quantization loss, therefore use the non-tiled flash attention kernels instead
launch_params.granular_tiling = false;
}
else if (isSm8x && params.d < 256)
{
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
launch_params.granular_tiling = false;
}
else if (isSm90 || isSm80 || isSm8x)
{
// otherwise, choose tiled kernel for Ampere/Ada
launch_params.granular_tiling = true;
}
}
// alibi.
if (has_alibi)
{
pagedKVParams.has_alibi = true;
pagedKVParams.alibi_params = AlibiParams(mNumHeads, s_kv, tp_size, tp_rank, scale_after_alibi);
}
// Sliding_window_causal mask.
if (s_kv > sliding_window_size && launch_params.attention_mask_type == ContextAttentionMaskType::CAUSAL)
{
pagedKVParams.sliding_window_size = sliding_window_size;
launch_params.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}
}
// NOTE: assume that heads_interleaved = false (b, s, 3, h, d), and sequences are padded/non-padded
// TMA descriptors are used as grid_constant parameters (remove MemCpyH2D operations)
void set_tma_descriptors()
@ -246,14 +339,14 @@ public:
if (sTmaMetaInfo[i].mD == params.d)
{
q_step = sTmaMetaInfo[i].mQStep;
kv_step = sTmaMetaInfo[i].mKvStep;
kv_step = sTmaMetaInfo[i].mKVStep;
break;
}
}
// QKV [TOTAL, 3, h, d]
// NOTE: we may need to use actual seqlen to set oob_value
char* qkv_ptr = reinterpret_cast<char*>(params.qkv_ptr);
const char* qkv_ptr = reinterpret_cast<const char*>(params.qkv_ptr);
tensor_size_qkv[3] = mTotalSeqLen;
// Q: STEP_Q
@ -275,12 +368,131 @@ public:
&params.tma_desc_v);
}
// Q are contiguous in the shape of [B, S, H, D]
// Paged KV has [B, 2, NumBlocksPerSequence] buffers,
// and each points to the contiguous buffer with shape [H, TokensPerBlock, D]
// TMA descriptors need cudaMemcpyAsync since we need multiple tma descriptors in device memory.
void set_paged_kv_tma_descriptors(cudaStream_t stream)
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = pagedKVParams.d * sizeof(uint16_t);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
uint32_t q_step = 0, kv_step = 0;
for (unsigned int i = 0u; i < sizeof(sTmaPagedKVMetaInfo) / sizeof(sTmaPagedKVMetaInfo[0]); ++i)
{
if (sTmaPagedKVMetaInfo[i].mD == pagedKVParams.d)
{
q_step = sTmaPagedKVMetaInfo[i].mQStep;
kv_step = sTmaPagedKVMetaInfo[i].mKVStep;
break;
}
}
// Separate q, and paged kv tma descriptors.
Multiple_tma_descriptor<4> q_tma_descriptor;
Multiple_tma_descriptor<4> paged_kv_tma_descriptor(
pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence);
// Contiguous Q
// query tensor size [B x S, 1, H, D]
uint32_t tensor_size_q[4];
tensor_size_q[3] = mTotalSeqLen;
tensor_size_q[2] = 1;
tensor_size_q[1] = pagedKVParams.h;
tensor_size_q[0] = pagedKVParams.d;
// box size for k and v
uint32_t box_size_q[4];
box_size_q[3] = q_step;
box_size_q[2] = 1;
box_size_q[1] = 1;
box_size_q[0] = pagedKVParams.d / d_groups;
// stride size in bytes.
uint64_t tensor_stride_q[3];
tensor_stride_q[0] = tensor_size_q[0] * sizeof(uint16_t);
tensor_stride_q[1] = tensor_size_q[1] * tensor_stride_q[0];
tensor_stride_q[2] = tensor_size_q[2] * tensor_stride_q[1];
// traversal stride
uint32_t traversal_stride[4] = {1, 1, 1, 1};
// OOB fill zeros
uint32_t oob_fill = 0;
// FP32 to TF32 conversion disabled
uint32_t fp32_to_tf32 = 0;
// gmma descriptor mode
const uint32_t d_bytes_per_group = (pagedKVParams.d * sizeof(uint16_t)) / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
// Q ptr.
const char* q_ptr = reinterpret_cast<const char*>(pagedKVParams.q_ptr);
// Q: STEP_Q.
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_q, tensor_stride_q, traversal_stride, box_size_q, oob_fill, fp32_to_tf32,
&pagedKVParams.tma_desc_q);
// Paged KV
// Per batch tensor size.
uint32_t tensor_size_kv[4];
tensor_size_kv[3] = 1;
tensor_size_kv[2] = pagedKVParams.h_kv;
tensor_size_kv[1] = pagedKVParams.paged_kv_cache.mTokensPerBlock;
tensor_size_kv[0] = pagedKVParams.d;
// Box size for k and v.
uint32_t box_size_kv[4];
box_size_kv[3] = 1;
box_size_kv[2] = 1;
box_size_kv[1] = kv_step;
box_size_kv[0] = pagedKVParams.d / d_groups;
// Stride size in bytes.
uint64_t tensor_stride_kv[3];
tensor_stride_kv[0] = tensor_size_kv[0] * sizeof(uint16_t);
tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0];
tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1];
// 2 stands for the K and V blocks.
// We only need to prepare as many tma descriptors as the number of paged kv blocks used for context.
for (int block_idx = 0; block_idx < pagedKVParams.b * 2 * launch_params.blocks_per_context_sequence;
block_idx++)
{
int block_ptr_idx = int(block_idx / launch_params.blocks_per_context_sequence)
* pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq
+ (block_idx % launch_params.blocks_per_context_sequence);
paged_kv_tma_descriptor.set_tma_desctriptor(
reinterpret_cast<char*>(launch_params.paged_kv_block_ptrs[block_ptr_idx]), cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, block_idx);
}
// set mMaxBlocksPerSeq to the number of blocks needed for context.
pagedKVParams.paged_kv_cache.mMaxBlocksPerSeq = launch_params.blocks_per_context_sequence;
paged_kv_tma_descriptor.copy_to_device(pagedKVParams.tma_desc_paged_kv, stream);
}
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
{
// BF16 FMHA only accumulates on FP32
launch_params.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
launch_params.attention_mask_type
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
// Paged KV Cache.
pagedKVParams.h_kv = num_kv_heads;
TLLM_CHECK_WITH_INFO(mNumHeads % num_kv_heads == 0, "number of Query heads should be multiple of KV heads !");
pagedKVParams.h_q_per_kv = mNumHeads / num_kv_heads;
pagedKVParams.is_s_padded = is_s_padded;
// Contiguous Cache.
params.h_kv = num_kv_heads;
params.is_s_padded = is_s_padded;
}
@ -290,11 +502,11 @@ public:
return MHARunner::fmha_supported(mHeadSize, sm);
}
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* output, cudaStream_t stream)
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{
params.qkv_ptr = const_cast<void*>(qkvPtr);
params.o_ptr = output;
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cuSeqlenPtr));
params.qkv_ptr = qkvPtr;
params.o_ptr = outputPtr;
params.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
if (sm == kSM_90 && launch_params.use_tma)
{
@ -305,9 +517,31 @@ public:
xmmaKernel->run(params, launch_params, stream);
}
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream)
{
pagedKVParams.q_ptr = qPtr;
pagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
pagedKVParams.paged_kv_cache = pagedKVCache;
pagedKVParams.o_ptr = outputPtr;
pagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
pagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
// paged kv block device ptrs on host (used by tma descriptors).
launch_params.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
if (sm == kSM_90 && launch_params.use_tma)
{
// memcpy H2D is needed as we use multiple tma descriptors in device memory.
set_paged_kv_tma_descriptors(stream);
}
pagedKVXmmaKernel->run(pagedKVParams, launch_params, stream);
}
bool isValid(int s) const
{
return xmmaKernel->isValid(s);
return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s);
}
int getSFromMaxSeqLen(const int max_seq_len)
@ -345,9 +579,11 @@ public:
private:
Fused_multihead_attention_params_v2 params;
Fused_multihead_attention_paged_kv_params_v2 pagedKVParams;
Launch_params launch_params;
int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;
bool use_flash_attention = false;
const Data_type mDataType;
const int mNumHeads;
@ -372,6 +608,14 @@ void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
}
void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank)
{
pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size,
total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
}
bool FusedMHARunnerV2::fmha_supported()
{
return pimpl->fmha_supported();
@ -383,9 +627,17 @@ void FusedMHARunnerV2::setup_flags(
pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads);
}
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* output, cudaStream_t stream)
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{
pimpl->run(qkvPtr, cuSeqlenPtr, output, stream);
pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream);
}
void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream)
{
pimpl->run_paged_kv(
qPtr, pagedKVTmaDesc, pagedKVBlockPtrsOnHost, pagedKVCache, cuQSeqlenPtr, cuKVSeqlenPtr, outputPtr, stream);
}
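
For orientation, here is a hedged usage sketch of the new paged-KV entry points. Only the signatures come from this file; the call ordering (setup_flags, then setup_paged_kv, then run_paged_kv) and every numeric value are illustrative assumptions, not values taken from this commit.

#include <climits>
#include <cuda_runtime_api.h>
// Assumes this commit's fmha runner and kvCacheUtils headers are already included.

// Illustrative driver (not library code). blocksPerSeq follows the usual
// ceil-division of the KV length by the block size: 512 / 64 = 8.
void runPagedKVContextFMHA(tensorrt_llm::kernels::FusedMHARunnerV2& runner, const void* q, void* tmaDescDeviceBuf,
    const void* kvBlockPtrsOnHost, const tensorrt_llm::kernels::KVBlockArray& pagedKV, const void* cuQSeqlens,
    const void* cuKVSeqlens, void* out, cudaStream_t stream)
{
    const int b = 8, sQ = 512, sKV = 512, tokensPerBlock = 64;
    const int blocksPerSeq = (sKV + tokensPerBlock - 1) / tokensPerBlock; // 8
    const int numKvHeads = 8, totalSeqlen = b * sQ;
    runner.setup_flags(/*force_fp32_acc=*/false, /*is_s_padded=*/false, /*causal_mask=*/true, numKvHeads);
    runner.setup_paged_kv(b, sQ, sKV, blocksPerSeq, tokensPerBlock, /*sliding_window_size=*/INT_MAX, totalSeqlen,
        /*has_alibi=*/false, /*scale_alibi=*/false, /*tp_size=*/1, /*tp_rank=*/0);
    runner.run_paged_kv(q, tmaDescDeviceBuf, kvBlockPtrsOnHost, pagedKV, cuQSeqlens, cuKVSeqlens, out, stream);
}
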
bool FusedMHARunnerV2::isValid(int s) const

View File

@ -51,6 +51,11 @@ public:
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
= 0;
virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
= 0;
static bool fmha_supported(const int headSize, const int sm);
virtual bool fmha_supported() = 0;
@ -61,6 +66,11 @@ public:
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
cudaStream_t stream)
= 0;
virtual bool isValid(int s) const = 0;
};
@ -84,9 +94,17 @@ public:
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
const int tp_rank = 0) override;
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
const int tp_rank = 0) override;
bool fmha_supported() override;
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
cudaStream_t stream) override;
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
const int num_kv_heads /* MQA or GQA */) override;

View File

@ -16,26 +16,17 @@
#pragma once
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tmaDescriptor.h"
#include <limits.h>
#include <stdint.h>
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
namespace tensorrt_llm
{
namespace kernels
{
enum Data_type
{
DATA_TYPE_BOOL,
DATA_TYPE_FP16,
DATA_TYPE_FP32,
DATA_TYPE_INT4,
DATA_TYPE_INT8,
DATA_TYPE_INT32,
DATA_TYPE_BF16,
DATA_TYPE_E4M3,
DATA_TYPE_E5M2
};
enum class ContextFMHAType
{
@ -52,14 +43,6 @@ enum class ContextAttentionMaskType
SLIDING_WINDOW_CAUSAL
};
constexpr int32_t kSM_70 = 70;
constexpr int32_t kSM_72 = 72;
constexpr int32_t kSM_75 = 75;
constexpr int32_t kSM_80 = 80;
constexpr int32_t kSM_86 = 86;
constexpr int32_t kSM_89 = 89;
constexpr int32_t kSM_90 = 90;
struct AlibiParams
{
constexpr static int round_down_to_power_two(int x)
@ -101,9 +84,9 @@ struct AlibiParams
struct Fused_multihead_attention_params_v2
{
// The QKV matrices.
void* qkv_ptr;
const void* qkv_ptr;
// The mask to implement drop-out.
void* packed_mask_ptr;
const void* packed_mask_ptr;
// The O matrix (output).
void* o_ptr;
@ -123,7 +106,7 @@ struct Fused_multihead_attention_params_v2
bool enable_i2f_trick;
// array of length b+1 holding prefix sum of actual sequence lengths
int* cu_seqlens;
const int* cu_seqlens;
// use C/32 Format.
bool interleaved = false;
@ -183,6 +166,102 @@ struct Fused_multihead_attention_params_v2
use_int8_scale_max = false;
h_kv = 0;
sliding_window_size = INT_MAX;
is_s_padded = false;
has_alibi = false;
alibi_params = AlibiParams{};
}
};
struct Fused_multihead_attention_paged_kv_params_v2
{
// The Q matrices.
const void* q_ptr;
// Paged KV Cache buffer.
KVBlockArray paged_kv_cache;
// The O matrix (output).
void* o_ptr;
// The packed mask for random mask.
const void* packed_mask_ptr;
// The stride between rows of the Q matrices.
int64_t q_stride_in_bytes;
// The stride between rows of the paged KV matrices.
int64_t kv_stride_in_bytes;
// The stride between rows of O.
int64_t o_stride_in_bytes;
// The stride between matrices of packed mask.
int64_t packed_mask_stride_in_bytes;
// The dimensions.
int b, h, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
// See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details.
bool enable_i2f_trick;
// true: for int8, instead of doing max reduce, use max value encoded in scale factor
bool use_int8_scale_max = false;
// If the kernel is using alibi or not
bool has_alibi = false;
AlibiParams alibi_params;
// array of length b+1 holding prefix sum of actual kv sequence lengths.
const int* cu_seqlens;
// Chunked attention (only handles one tile of Q).
const int* cu_q_seqlens;
// q with shape [B, S, H, D] in const cache.
cudaTmaDesc tma_desc_q;
// Tma descriptors for paged kv cache.
// paged kv has [B, 2, Num_blocks] buffers,
// and each points to [Tokens_per_block, H, D] contiguous memory.
cudaTmaDesc* tma_desc_paged_kv;
// In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
int h_kv = 0;
int h_q_per_kv = 0;
// Sliding Window Attention
// Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
int sliding_window_size = INT_MAX;
// is input/output padded
bool is_s_padded = false;
void clear()
{
q_ptr = nullptr;
o_ptr = nullptr;
packed_mask_ptr = nullptr;
q_stride_in_bytes = 0;
kv_stride_in_bytes = 0;
o_stride_in_bytes = 0;
packed_mask_stride_in_bytes = 0;
b = 0;
h = 0;
s = 0;
d = 0;
// The scaling factors for the kernel.
scale_bmm1 = 0;
scale_softmax = 0;
scale_bmm2 = 0;
enable_i2f_trick = false;
cu_seqlens = nullptr;
cu_q_seqlens = nullptr;
use_int8_scale_max = false;
h_kv = 0;
h_q_per_kv = 0;
sliding_window_size = INT_MAX;
is_s_padded = false;
has_alibi = false;
@ -195,6 +274,8 @@ struct Launch_params
{
// seq_length to select the kernel
int kernel_s = 0;
// kv_seq_length to set launch strategies.
int kernel_kv_s = 0;
// flags to control small batch kernel choice
// true: never unroll
bool ignore_b1opt = false;
@ -208,6 +289,10 @@ struct Launch_params
bool use_tma = false;
// host seqlens to set tma descriptors
int* seqlens = nullptr;
// number of paged kv blocks for context sequence.
int blocks_per_context_sequence = 0;
// device ptrs on the host for paged kv cache.
const int64_t* paged_kv_block_ptrs = nullptr;
// if flash attention is used (only FP16)
bool flash_attention = false;
// if warp_specialized kernels are used (only SM90 HGMMA + TMA)

View File

@ -18,6 +18,7 @@
#include "cubin/fmha_cubin.h"
#include "cuda_runtime_api.h"
#include "fused_multihead_attention_common.h"
#include "pagedKVCubin/fmha_cubin.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tmaDescriptor.h"
@ -44,9 +45,19 @@ static const struct TmaKernelMetaInfo
{
unsigned int mD;
unsigned int mQStep;
unsigned int mKvStep;
unsigned int mKVStep;
} sTmaMetaInfo[] = {{32, 64, 256}, {64, 64, 256}, {128, 64, 128}, {256, 64, 64}};
////////////////////////////////////////////////////////////////////////////////////////////////////
// meta info for tma warp-specialized kernels that supports paged kv cache
static const struct TmaPagedKVKernelMetaInfo
{
unsigned int mD;
unsigned int mQStep;
unsigned int mKVStep;
} sTmaPagedKVMetaInfo[] = {{32, 64, 128}, {64, 64, 128}, {128, 64, 128}, {256, 64, 64}};
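
These meta-info tables tie each supported head size to the Q/KV tile steps of the Hopper TMA warp-specialized kernels (for example, d = 128 uses 64 x 128 tiles on the paged-KV path). A small illustrative lookup over the paged-KV table, purely as a reading aid (the helper itself is hypothetical, not part of this header):

// Returns the paged-KV KV step for a head size, or 0 if no TMA paged-KV kernel
// is listed for it in sTmaPagedKVMetaInfo above.
inline unsigned int pagedKVStepForHeadSize(unsigned int headSize)
{
    for (const auto& info : sTmaPagedKVMetaInfo)
    {
        if (info.mD == headSize)
        {
            return info.mKVStep;
        }
    }
    return 0;
}
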
////////////////////////////////////////////////////////////////////////////////////////////////////
// Base Class
@ -206,6 +217,8 @@ private:
////////////////////////////////////////////////////////////////////////////////////////////////////
// FMHA kernels that support Contiguous QKV input.
class FusedMultiHeadAttentionXMMAKernelV2
: public TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2,
Fused_multihead_attention_params_v2>
@ -291,7 +304,7 @@ public:
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
} // forceunroll = true for flash attention kernels
else if (mSM == kSM_90 && launch_params.flash_attention)
else if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
{
// tricks for launching warp-specialized flash attention kernels on Hopper
dim3 block_size(1, std::min(params.b * params.h, launch_params.multi_processor_count));
@ -302,9 +315,9 @@ public:
size_t m_steps = size_t((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep);
m_steps = size_t((m_steps + NUM_COMPUTE_GROUPS - 1) / NUM_COMPUTE_GROUPS) * NUM_COMPUTE_GROUPS;
size_t size_in_bytes = params.b * params.s * params.qkv_stride_in_bytes;
if (launch_params.attention_mask_type == ContextAttentionMaskType::PADDING
&& size_in_bytes <= launch_params.device_l2_cache_size)
// 2 * 2 accounts for the K and V caches and 2 bytes per element.
size_t size_in_bytes = block_size.y * params.s * params.d * 2 * 2;
if (size_in_bytes <= launch_params.device_l2_cache_size)
{
// strategy 1: limit to only 1 wave
block_size.x = std::min(m_steps / NUM_COMPUTE_GROUPS, sms_per_head);
@ -328,8 +341,8 @@ public:
{
unroll = (params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
}
// on Hopper, we still launch blocks (h, b, steps)
if (mSM == kSM_90)
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
if (mSM == kSM_90 && !launch_params.flash_attention)
{
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
@ -353,5 +366,115 @@ inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type typ
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FMHA kernels that support Paged KV Cache and Chunked Attention.
class FusedMultiHeadAttentionPagedKVXMMAKernelV2
: public TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
Fused_multihead_attention_paged_kv_params_v2>
{
public:
FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm)
{
}
inline uint64_t hashID(unsigned int s, unsigned int d, bool interleaved, bool unroll, bool force_fp32_acc,
bool flash_attention, bool warp_specialization, int attention_mask_type, bool tiled) const
{
s = flash_attention ? 0 : s;
// D <= 2048
return (uint64_t) s << 32 | d << 16 | (attention_mask_type << 6) | (warp_specialization ? 32ull : 0ull)
| (tiled ? 16ull : 0ull) | (force_fp32_acc ? 8ull : 0ull) | (flash_attention ? 4ull : 0ull)
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}
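
For readers tracing kernel selection, the constants below restate the lookup-key layout implied by the shifts in hashID() above. The names are ours and serve only as a reading aid; they are not part of the kernel metadata.

#include <cstdint>

// Hedged restatement of the 64-bit lookup key built by hashID() above.
namespace fmha_paged_kv_key_layout
{
constexpr uint64_t kUnrollBit = 1ull << 0;
constexpr uint64_t kInterleavedBit = 1ull << 1;
constexpr uint64_t kFlashAttentionBit = 1ull << 2;
constexpr uint64_t kForceFp32AccBit = 1ull << 3;
constexpr uint64_t kTiledBit = 1ull << 4;
constexpr uint64_t kWarpSpecializationBit = 1ull << 5;
constexpr int kAttentionMaskTypeShift = 6;  // ContextAttentionMaskType value
constexpr int kHeadSizeShift = 16;          // d, hence the "D <= 2048" note above
constexpr int kSeqLenShift = 32;            // s; forced to 0 for flash-attention kernels
} // namespace fmha_paged_kv_key_layout
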
virtual uint64_t hashID(const KernelMeta& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
kernelMeta.mAttentionMaskType, kernelMeta.mTiled);
}
virtual void run(
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
{
const auto findIter = mFunctions.find(
hashID(launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll,
launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "FMHA kernels are not found");
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr};
if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
{
// tricks for launching warp-specialized flash attention kernels on Hopper
dim3 block_size(1, std::min(params.b * params.h, launch_params.multi_processor_count));
// distribute m steps to multiple blocks (fully utilize SMs)
// block.x = blocks that handle single head, block.y = blocks that handle different heads
size_t sms_per_head = (launch_params.multi_processor_count) / block_size.y;
size_t m_steps = size_t((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep);
m_steps = size_t((m_steps + NUM_COMPUTE_GROUPS - 1) / NUM_COMPUTE_GROUPS) * NUM_COMPUTE_GROUPS;
// 2 * 2 accounts for the K and V caches and 2 bytes per element.
size_t size_in_bytes = block_size.y * launch_params.kernel_kv_s * params.d * 2 * 2;
if (size_in_bytes <= launch_params.device_l2_cache_size)
{
// strategy 1: limit to only 1 wave
block_size.x = std::min(m_steps / NUM_COMPUTE_GROUPS, sms_per_head);
}
else
{
// strategy 2: fully unroll the q loops (contiguous blocks handle all q loops)
block_size.x = m_steps / NUM_COMPUTE_GROUPS;
}
cuErrCheck(mDriver.cuLaunchKernel(func, block_size.x, block_size.y, block_size.z, kernelMeta.mThreadsPerCTA,
1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
}
else
{ // forceunroll = true for flash attention kernels
int unroll = kernelMeta.mS / kernelMeta.mUnrollStep;
TLLM_CHECK_WITH_INFO(kernelMeta.mS == kernelMeta.mUnrollStep * unroll, "Wrong launching sequence length");
// flash attention supports any sequence length, so we use the runtime s here
if (launch_params.flash_attention)
{
unroll = (params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
}
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
if (mSM == kSM_90 && !launch_params.flash_attention)
{
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
} // on Ampere/Ada flash attention, we launch blocks (steps, h, b)
else
{
cuErrCheck(mDriver.cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
}
}
}
};
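
To make the Hopper grid-sizing heuristic in run() above concrete: block_size.y covers batch x heads (capped at the SM count), and block_size.x is limited to a single wave only when the estimated KV traffic fits in L2; otherwise every group of q steps gets its own block. A hedged numeric sketch, with NUM_COMPUTE_GROUPS = 2 assumed on our part and all other numbers illustrative:

#include <algorithm>
#include <cstddef>

// Standalone restatement of the strategy choice above (not library code).
size_t pickBlocksPerHead(size_t b, size_t h, size_t kvSeqLen, size_t headSize, size_t sQ, size_t qUnrollStep,
    size_t smCount, size_t l2CacheBytes, size_t numComputeGroups = 2)
{
    size_t blocksY = std::min(b * h, smCount);
    size_t smsPerHead = smCount / blocksY;
    size_t mSteps = (sQ + qUnrollStep - 1) / qUnrollStep;
    mSteps = (mSteps + numComputeGroups - 1) / numComputeGroups * numComputeGroups;
    // 2 * 2 accounts for the K and V caches and 2 bytes per fp16/bf16 element.
    size_t kvBytes = blocksY * kvSeqLen * headSize * 2 * 2;
    return kvBytes <= l2CacheBytes ? std::min(mSteps / numComputeGroups, smsPerHead) // strategy 1: one wave
                                   : mSteps / numComputeGroups;                      // strategy 2: unroll all q loops
}

For example, with b * h = 64, kvSeqLen = 4096 and headSize = 128, the estimate is 64 * 4096 * 128 * 4, roughly 134 MB, well above a typical ~50 MB Hopper L2, so the heuristic falls back to strategy 2.
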
using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>;
inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
{
return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2,
sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm);
}
} // namespace kernels
} // namespace tensorrt_llm

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff