Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#422)

* Update TensorRT-LLM

Co-authored-by: Tltin <TltinDeng01@gmail.com>
Co-authored-by: zhaohb <zhaohbcloud@126.com>
Co-authored-by: Bradley Heilbrun <brad@repl.it>
Co-authored-by: nqbao11 <nqbao11.01@gmail.com>
Co-authored-by: Nikhil Varghese <nikhil@bot-it.ai>

This commit is contained in: parent ab7b4614b8, commit 6755a3f077
3rdparty/cutlass (vendored), 2 changes

@@ -1 +1 @@
-Subproject commit fc9ebc645b63f3a6bc80aaefde5c063fb72110d6
+Subproject commit 39c6a83f231d6db2bc6b9c251e7add77d68cbfb4
3rdparty/json (vendored), 2 changes

@@ -1 +1 @@
-Subproject commit 5fec8034933ef434a98dfbd2551b052c56345869
+Subproject commit bc889afb4c5bf1c0d8ee29ef35eaaf4c8bef8a5d
README.md, 61 changes

@@ -43,17 +43,22 @@ H200 FP8 achieves 11,819 tok/s on Llama2-13B on a single GPU, and is up to 1.9x
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Support Matrix](#support-matrix)
- [Devices](#devices)
- [Precision](#precision)
- [Key Features](#key-features)
- [Models](#models)
- [Performance](#performance)
- [Advanced Topics](#advanced-topics)
- [Quantization](#quantization)
- [In-flight Batching](#in-flight-batching)
- [Attention](#attention)
- [Graph Rewriting](#graph-rewriting)
- [Benchmarking](#benchmarking)
- [Benchmark](#benchmark)
- [Troubleshooting](#troubleshooting)
- [Release Notes](#release-notes)
- [Changelog](#changelog)
- [Known issues](#known-issues)
- [Release notes](#release-notes)
- [Change Log](#change-log)
- [Known Issues](#known-issues)
- [Report Issues](#report-issues)

## TensorRT-LLM Overview
@@ -154,14 +159,14 @@ See the BLOOM [example](examples/bloom) for more details and options regarding t

***3. Run***

The `summarize.py` script can be used to perform the summarization of articles
The `../summarize.py` script can be used to perform the summarization of articles
from the CNN Daily dataset:

```python
python summarize.py --test_trt_llm \
                    --hf_model_location ./bloom/560M/ \
                    --data_type fp16 \
                    --engine_dir ./bloom/560M/trt_engines/fp16/1-gpu/
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./bloom/560M/ \
                       --data_type fp16 \
                       --engine_dir ./bloom/560M/trt_engines/fp16/1-gpu/
```

More details about the script and how to run the BLOOM model can be found in
@@ -237,19 +242,26 @@ The list of supported models is:
* [Bert](examples/bert)
* [Blip2](examples/blip2)
* [BLOOM](examples/bloom)
* [ChatGLM](examples/chatglm), including ChatGLM-6B, ChatGLM2-6B, ChatGLM2-6B-32k, ChatGLM3-6B, ChatGLM3-6B-32k
* [ChatGLM](examples/chatglm)
* [Falcon](examples/falcon)
* [Flan-T5](examples/enc_dec)
* [GPT](examples/gpt)
* [GPT-J](examples/gptj)
* [GPT-Nemo](examples/gpt)
* [GPT-NeoX](examples/gptneox)
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [Mistral](examples/llama)
* [MPT](examples/mpt)
* [OPT](examples/opt)
* [Qwen](examples/qwen)
* [Replit Code](examples/mpt)
* [SantaCoder](examples/gpt)
* [StarCoder](examples/gpt)
* [InternLM](examples/internlm)
* [T5](examples/enc_dec)

Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support covering many encoder-decoder models such as T5 and Flan-T5. We unroll the exact model names in the list above to let users find specific models more easily.
## Performance

@@ -311,6 +323,33 @@ may happen. One possible solution is to reduce the amount of memory needed by
reducing the maximum batch size, input and output lengths. Another option is to
enable plugins, for example: `--use_gpt_attention_plugin`.

* MPI + Slurm

TensorRT-LLM is an [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface)-aware package that uses [`mpi4py`](https://mpi4py.readthedocs.io/en/stable/). If you are running scripts in a [Slurm](https://slurm.schedmd.com/) environment, you might encounter interference:
```
--------------------------------------------------------------------------
PMI2_Init failed to initialize. Return code: 14
--------------------------------------------------------------------------
--------------------------------------------------------------------------
The application appears to have been direct launched using "srun",
but OMPI was not built with SLURM's PMI support and therefore cannot
execute. There are several options for building PMI support under
SLURM, depending upon the SLURM version you are using:

  version 16.05 or later: you can use SLURM's PMIx support. This
  requires that you configure and build SLURM --with-pmix.

  Versions earlier than 16.05: you must use either SLURM's PMI-1 or
  PMI-2 support. SLURM builds PMI-1 by default, or you can manually
  install PMI-2. You must then build Open MPI using --with-pmi pointing
  to the SLURM PMI library location.

Please configure as appropriate and try again.
--------------------------------------------------------------------------
```
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a dedicated MPI environment, not the one provided by your Slurm allocation.
For example: `mpirun -n 1 python3 examples/gpt/build.py ...`

## Release notes

* TensorRT-LLM requires TensorRT 9.1.0.4 and the 23.08 containers.
@@ -7,18 +7,14 @@ multiple GPUs or multiple nodes with multiple GPUs.

### 1. Build TensorRT-LLM and benchmarking source code

Please follow the [`installation document`](../../../README.md) to build TensorRT-LLM.
Please follow the [`installation document`](../../docs/source/installation.md) to build TensorRT-LLM.

Note that the benchmarking source code for the C++ runtime is not built by default; you can use the argument `--benchmarks` in [`build_wheel.py`](../../scripts/build_wheel.py) to build it.

Windows users: Follow the
[`Windows installation document`](../../../windows/README.md)
[`Windows installation document`](../../windows/README.md)
instead, and be sure to set DLL paths as specified in
[Extra Steps for C++ Runtime Usage](../../../windows/README.md#extra-steps-for-c-runtime-usage).

After that, you can build the benchmarking source code for the C++ runtime:
```
cd cpp/build
make -j benchmarks
```
[Extra Steps for C++ Runtime Usage](../../windows/README.md#extra-steps-for-c-runtime-usage).

### 2. Launch C++ benchmarking (Fixed BatchSize/InputLen/OutputLen)

@@ -59,6 +55,8 @@ mpirun -n 8 ./benchmarks/gptSessionBenchmark \
# [BENCHMARK] batch_size 1 input_length 60 output_length 20 latency(ms) 792.14
```

If you want to obtain context and generation logits, you can build an engine with `--gather_all_token_logits` and run gptSessionBenchmark with `--print_all_logits`. This prints a large number of logit values and has some impact on performance.

*Please note that the expected outputs in that document are only for reference; specific performance numbers depend on the GPU you're using.*

### 3. Launch Batch Manager benchmarking (Inflight/V1 batching)
@ -18,6 +18,7 @@
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
#include "tensorrt_llm/runtime/gptSession.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/memoryCounters.h"
|
||||
#include "tensorrt_llm/runtime/tllmLogger.h"
|
||||
|
||||
@ -37,7 +38,7 @@ namespace
|
||||
void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
|
||||
std::vector<int> const& batchSizes, int beamWidth, std::vector<std::vector<int>> const& inOutLen,
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
|
||||
GptSession::Config& sessionConfig, bool cudaGraphMode)
|
||||
GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits)
|
||||
{
|
||||
|
||||
std::string modelNameHyphen = modelName;
|
||||
@ -60,7 +61,6 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
|
||||
SamplingConfig samplingConfig{beamWidth};
|
||||
samplingConfig.temperature = std::vector{1.0f};
|
||||
samplingConfig.minLength = std::vector{1};
|
||||
samplingConfig.randomSeed = std::vector{42ull};
|
||||
samplingConfig.topK = std::vector{1};
|
||||
samplingConfig.topP = std::vector{0.0f};
|
||||
@ -77,6 +77,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
auto const maxNewTokens = inOut[1];
|
||||
|
||||
sessionConfig.maxSequenceLength = maxInputLength + maxNewTokens;
|
||||
samplingConfig.minLength = std::vector{maxNewTokens};
|
||||
|
||||
GptSession session{sessionConfig, modelConfig, worldConfig, enginePath.string(), logger};
|
||||
|
||||
@ -102,6 +103,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
// copy inputs and wrap into shared_ptr
|
||||
GenerationInput::TensorPtr inputIds;
|
||||
std::vector<int32_t> inputsHost(batchSize * maxInputLength, padId);
|
||||
|
||||
if (inputPacked)
|
||||
{
|
||||
inputIds = bufferManager.copyFrom(
|
||||
@ -123,6 +125,17 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
|
||||
bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};
|
||||
|
||||
if (session.getModelConfig().computeContextLogits())
|
||||
{
|
||||
generationOutput.contextLogits
|
||||
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
|
||||
}
|
||||
if (session.getModelConfig().computeGenerationLogits())
|
||||
{
|
||||
generationOutput.generationLogits
|
||||
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
|
||||
bufferManager.setZero(*generationOutput.generationLogits);
|
||||
}
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
|
||||
for (auto r = 0; r < warmUp; ++r)
|
||||
@ -168,6 +181,30 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
"%.2f\n",
|
||||
batchSize, maxInputLength, maxNewTokens, averageLatency, tokensPerSec);
|
||||
}
|
||||
|
||||
// logits are stored on the last rank
|
||||
if (worldConfig.getRank() == worldConfig.getSize() - 1)
|
||||
{
|
||||
if (session.getModelConfig().computeContextLogits() && printAllLogits)
|
||||
{
|
||||
std::cout << "generationOutput.contextLogits.shape: "
|
||||
<< generationOutput.contextLogits->getShape()
|
||||
<< std::endl; // (batchsize, prompt_len, vocabsize)
|
||||
std::cout << "generationOutput.contextLogits" << *generationOutput.contextLogits << std::endl;
|
||||
}
|
||||
|
||||
if (session.getModelConfig().computeGenerationLogits() && printAllLogits)
|
||||
{
|
||||
std::cout << "generationOutput.generationLogits.shape: "
|
||||
<< generationOutput.generationLogits->getShape()
|
||||
<< std::endl; // (batchsize, beamwidth, maxNewTokens-1, vocabsize)
|
||||
generationOutput.generationLogits->reshape(ITensor::makeShape({batchSize * beamWidth,
|
||||
maxNewTokens - 1, modelConfig.getVocabSizePadded(worldConfig.getSize())}));
|
||||
|
||||
std::cout << "generationOutput.generationLogits: " << *generationOutput.generationLogits
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::runtime_error& e)
|
||||
{
|
||||
@ -231,6 +268,7 @@ int main(int argc, char* argv[])
|
||||
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
|
||||
|
||||
options.add_options()("enable_cuda_graph", "Execute GPT session with CUDA graph.");
|
||||
options.add_options()("print_all_logits", "Print all context and generation logits.");
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
@ -328,6 +366,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
// Argument: Enable CUDA graph
|
||||
auto enableCudaGraph = result.count("enable_cuda_graph") > 0;
|
||||
auto printAllLogits = result.count("print_all_logits") > 0;
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
@ -335,7 +374,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
|
||||
beamWidth, inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
|
||||
result["duration"].as<int>(), sessionConfig, enableCudaGraph);
|
||||
result["duration"].as<int>(), sessionConfig, enableCudaGraph, printAllLogits);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
|
||||
@ -353,7 +353,7 @@ _allowed_configs = {
|
||||
builder_opt=None,
|
||||
)),
|
||||
"chatglm_6b":
|
||||
ModelConfig(name="chatglm-6b",
|
||||
ModelConfig(name="chatglm_6b",
|
||||
family="chatglm",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
@ -370,7 +370,7 @@ _allowed_configs = {
|
||||
remove_input_padding=False,
|
||||
)),
|
||||
"chatglm2_6b":
|
||||
ModelConfig(name="chatglm2-6b",
|
||||
ModelConfig(name="chatglm2_6b",
|
||||
family="chatglm2",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
@ -387,7 +387,7 @@ _allowed_configs = {
|
||||
remove_input_padding=False,
|
||||
)),
|
||||
"chatglm3_6b":
|
||||
ModelConfig(name="chatglm3-6b",
|
||||
ModelConfig(name="chatglm3_6b",
|
||||
family="chatglm3",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
|
||||
@ -143,7 +143,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
quant_mode=self.quant_mode,
|
||||
use_custom_all_reduce=self.enable_custom_all_reduce,
|
||||
)
|
||||
if model_name == 'chatglm-6b':
|
||||
if model_name == 'chatglm_6b':
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=130005,
|
||||
pad_id=3,
|
||||
@ -152,16 +152,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
top_p=top_p)
|
||||
self.decoder = tensorrt_llm.runtime.ChatGLMGenerationSession(
|
||||
model_config, engine_buffer, self.runtime_mapping)
|
||||
elif model_name == 'chatglm2-6b':
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=2,
|
||||
pad_id=0,
|
||||
num_beams=num_beams,
|
||||
top_k=top_k,
|
||||
top_p=top_p)
|
||||
self.decoder = tensorrt_llm.runtime.GenerationSession(
|
||||
model_config, engine_buffer, self.runtime_mapping)
|
||||
elif model_name == 'chatglm3-6b':
|
||||
elif model_name in ['chatglm2_6b', 'chatglm3_6b']:
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=2,
|
||||
pad_id=0,
|
||||
@ -402,7 +393,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_version="1")
|
||||
model_name="chatglm_6b")
|
||||
elif family == "chatglm2":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
@ -418,7 +409,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_version="2")
|
||||
model_name="chatglm2_6b")
|
||||
elif family == "chatglm3":
|
||||
tensorrt_llm_model = tensorrt_llm.models.ChatGLMHeadModel(
|
||||
num_layers=self.num_layers,
|
||||
@ -434,7 +425,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
apply_query_key_layer_scaling=builder_config.
|
||||
apply_query_key_layer_scaling,
|
||||
quant_mode=self.quant_mode,
|
||||
model_version="3")
|
||||
model_name="chatglm3_6b")
|
||||
elif family == "bloom":
|
||||
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(
|
||||
num_layers=self.num_layers,
|
||||
@ -458,6 +449,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
max_position_embeddings=self.n_positions,
|
||||
dtype=kv_dtype,
|
||||
bias=self.bias,
|
||||
quant_mode=self.quant_mode,
|
||||
use_alibi=self.use_alibi,
|
||||
new_decoder_architecture=self.new_decoder_architecture,
|
||||
parallel_attention=self.parallel_attention,
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_memory_info(handle):
                                        version=pynvml.nvmlMemory_v2)
    total = round(mem_info.total / 1024 / 1024 / 1024, 2)
    used = round(mem_info.used / 1024 / 1024 / 1024, 2)
    free = round(mem_info.used / 1024 / 1024 / 1024, 2)
    free = round(mem_info.free / 1024 / 1024 / 1024, 2)
    return total, used, free
|
||||
|
||||
|
||||
|
||||
@ -237,6 +237,24 @@ if(WIN32)
|
||||
set(CMAKE_CXX_FLAGS "/DNOMINMAX ${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
if((MSVC))
|
||||
if((MSVC_VERSION GREATER_EQUAL 1914))
|
||||
# MSVC does not apply the correct __cplusplus version per the C++ standard
|
||||
# by default. This is required for compiling CUTLASS 3.0 kernels on windows
|
||||
# with C++-17 constexpr enabled. The 2017 15.7 MSVC adds /Zc:__cplusplus to
|
||||
# set __cplusplus to 201703 with std=c++17. See
|
||||
# https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus for
|
||||
# more info.
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
|
||||
else()
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Build is only supported with Visual Studio 2017 version 15.7 or higher"
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
|
||||
if(FAST_MATH)
|
||||
|
||||
@ -121,10 +121,13 @@ private:
|
||||
inline static const std::string kMinLengthTensorName_ = "min_length";
|
||||
inline static const std::string kPresencePenaltyTensorName_ = "presence_penalty";
|
||||
inline static const std::string kRandomSeedTensorName_ = "random_seed";
|
||||
inline static const std::string kReturnLogProbsTensorName_ = "return_log_probs";
|
||||
inline static const std::string kPromptEmbeddingTableName_ = "prompt_embedding_table";
|
||||
inline static const std::string kPromptVocabSizeName_ = "prompt_vocab_size";
|
||||
inline static const std::string kOutputIdsTensorName_ = "output_ids";
|
||||
inline static const std::string kSequenceLengthTensorName_ = "sequence_length";
|
||||
inline static const std::string kLogProbsTensorName_ = "output_log_probs";
|
||||
inline static const std::string kCumLogProbsTensorName_ = "cum_log_probs";
|
||||
|
||||
std::shared_ptr<nvinfer1::ILogger> mLogger{};
|
||||
};
|
||||
|
||||
@ -309,8 +309,8 @@ public:
|
||||
}
|
||||
|
||||
[[nodiscard]] static SizeType getMaxNumTokens(KvCacheConfig const& config, nvinfer1::DataType dtype,
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig,
|
||||
tensorrt_llm::runtime::WorldConfig const& worldConfig);
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
|
||||
runtime::BufferManager const& bufferManager);
|
||||
|
||||
private:
|
||||
void resetBlockPointers(SizeType batchSlotIdx, SizeType beamWidth);
|
||||
|
||||
@ -43,6 +43,7 @@ public:
|
||||
using TokenIdType = runtime::TokenIdType;
|
||||
using RequestIdType = std::uint64_t;
|
||||
using BeamTokens = std::vector<std::vector<TokenIdType>>;
|
||||
using VecLogProbs = std::vector<float>;
|
||||
using TensorPtr = runtime::ITensor::SharedPtr;
|
||||
|
||||
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
|
||||
@ -50,7 +51,7 @@ public:
|
||||
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
|
||||
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
|
||||
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt)
|
||||
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(input_tokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -60,11 +61,15 @@ public:
|
||||
, mEndId(endId)
|
||||
, mPadId(padId)
|
||||
, mBatchSlot(-1)
|
||||
, mOrigPromptLen(input_tokens->size())
|
||||
, mEmbeddingBias(embeddingBias)
|
||||
, mBadWordsList(badWordsList)
|
||||
, mStopWordsList(stopWordsList)
|
||||
, mPromptEmbeddingTable(promptEmbeddingTable)
|
||||
, mPromptVocabSize(promptVocabSize)
|
||||
, mReturnLogProbs(returnLogProbs)
|
||||
, mLogProbs(samplingConfig.beamWidth)
|
||||
, mCumLogProbs(samplingConfig.beamWidth)
|
||||
{
|
||||
mMaxSentTokenPos = mPromptLen - 1;
|
||||
// Scatter the input tokens to other beam
|
||||
@ -168,17 +173,29 @@ public:
|
||||
// As a temporary solution, we currently reset the tokens to the prompt
|
||||
if (mSamplingConfig.beamWidth > 1)
|
||||
{
|
||||
for (auto& beamTokens : *mTokens)
|
||||
for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
|
||||
{
|
||||
auto& beamTokens = mTokens->at(beam);
|
||||
beamTokens.resize(mPromptLen);
|
||||
if (mReturnLogProbs)
|
||||
{
|
||||
mLogProbs.at(beam).clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
|
||||
for (auto& beamTokens : *mTokens)
|
||||
for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
|
||||
{
|
||||
auto& beamTokens = mTokens->at(beam);
|
||||
beamTokens.resize(newPromptLen);
|
||||
|
||||
if (mReturnLogProbs)
|
||||
{
|
||||
auto& logProb = mLogProbs.at(beam);
|
||||
logProb.resize(newPromptLen - mPromptLen);
|
||||
}
|
||||
}
|
||||
mMaxNewTokens -= (newPromptLen - mPromptLen);
|
||||
mPromptLen = newPromptLen;
|
||||
@ -187,16 +204,16 @@ public:
|
||||
mBatchSlot = -1;
|
||||
}
|
||||
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to client
|
||||
/// duplicated token positions.
|
||||
/// @brief Get the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated token positions.
|
||||
/// @return The maximum position of the tokens sent to the client
|
||||
SizeType getMaxSentTokenPos() const
|
||||
{
|
||||
return mMaxSentTokenPos;
|
||||
}
|
||||
|
||||
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to client
|
||||
/// duplicated token positions.
|
||||
/// @brief Sets the maximum position of the tokens returned to the client. Use to ensure we don't return to
|
||||
/// client duplicated token positions.
|
||||
/// @param pos The maximum position
|
||||
void setMaxSentTokenPos(SizeType pos)
|
||||
{
|
||||
@@ -243,6 +260,42 @@ public:
        return mStopWordsList;
    }

    bool returnLogProbs() const
    {
        return mReturnLogProbs;
    }

    std::vector<VecLogProbs> const& getLogProbs() const
    {
        return mLogProbs;
    }

    VecLogProbs const& getLogProbs(SizeType beam) const
    {
        return mLogProbs.at(beam);
    }

    void setLogProbs(VecLogProbs const& logProbs, SizeType beam)
    {
        mLogProbs.at(beam).resize(mPromptLen - mOrigPromptLen);
        mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
    }

    VecLogProbs const& getCumLogProbs() const
    {
        return mCumLogProbs;
    }

    void setCumLogProb(float cumLogProb, SizeType beam)
    {
        mCumLogProbs.at(beam) = cumLogProb;
    }

    SizeType getOrigPromptLen() const
    {
        return mOrigPromptLen;
    }
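This change wires per-beam log probabilities through `LlmRequest` (the new `returnLogProbs` constructor flag plus the accessors above). The block below is only a hedged, illustrative sketch of how a batch-manager client might read the new fields; it is not code from the diff, and the include path and call site are assumptions.

```cpp
// Illustrative only: consume the per-beam log probs carried by a finished request.
// Assumes llmRequest.h (and its transitive headers) is included and that
// `request` was constructed with returnLogProbs = true.
#include <cstdio>

void reportLogProbs(tensorrt_llm::batch_manager::LlmRequest const& request, int beamWidth)
{
    if (!request.returnLogProbs())
        return; // log probs were not requested for this request
    for (int beam = 0; beam < beamWidth; ++beam)
    {
        auto const& perTokenLogProbs = request.getLogProbs(beam); // one entry per generated token
        std::printf("beam %d: %zu log probs, cumulative %.4f\n", beam,
            perTokenLogProbs.size(), request.getCumLogProbs().at(beam));
    }
}
```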
|
||||
|
||||
RequestIdType mRequestId;
|
||||
SizeType mPromptLen;
|
||||
SizeType mMaxNewTokens;
|
||||
@ -255,6 +308,7 @@ public:
|
||||
SizeType mBatchSlot;
|
||||
|
||||
private:
|
||||
SizeType mOrigPromptLen;
|
||||
std::shared_ptr<BeamTokens> mTokens;
|
||||
SizeType mMaxSentTokenPos;
|
||||
|
||||
@ -264,6 +318,11 @@ private:
|
||||
|
||||
std::optional<TensorPtr> mPromptEmbeddingTable;
|
||||
std::optional<SizeType> mPromptVocabSize;
|
||||
|
||||
bool mReturnLogProbs;
|
||||
|
||||
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
|
||||
VecLogProbs mCumLogProbs; // [beamSize]
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@@ -147,9 +147,33 @@ public:
    //! \brief Get the underlying cuda stream.
    [[nodiscard]] CudaStream const& getStream() const;

    //! \brief The current size of the memory reserved by the memory pool.
    [[nodiscard]] std::size_t memoryPoolReserved() const;

    //! \brief The current size of the memory used by the memory pool.
    [[nodiscard]] std::size_t memoryPoolUsed() const;

    //! \brief The current size of the memory free in the memory pool.
    [[nodiscard]] std::size_t memoryPoolFree() const;

    //! \brief Try to trim the memory reserved by the pool to `size` bytes. This synchronizes implicitly with the
    //! stream.
    void memoryPoolTrimTo(std::size_t size);

private:
    void static initMemoryPool(int device);

    std::size_t static memoryPoolReserved(int device);

    std::size_t static memoryPoolUsed(int device);

    std::size_t static memoryPoolFree(int device)
    {
        return memoryPoolReserved(device) - memoryPoolUsed(device);
    }

    void static memoryPoolTrimTo(int device, std::size_t size);

    CudaStreamPtr mStream;
};
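The new `memoryPool*` accessors expose the CUDA memory pool backing the `BufferManager`. The snippet below is a hedged usage sketch only, under the assumption that a `tensorrt_llm::runtime::BufferManager` instance is available; it is not taken from the repository.

```cpp
// Sketch: log pool statistics and release unused pool memory between runs.
#include <cstdio>

void logAndTrimPool(tensorrt_llm::runtime::BufferManager& bufferManager)
{
    std::printf("memory pool: reserved=%zu used=%zu free=%zu bytes\n",
        bufferManager.memoryPoolReserved(), bufferManager.memoryPoolUsed(),
        bufferManager.memoryPoolFree());
    // Per the declaration above, trimming synchronizes implicitly with the stream.
    bufferManager.memoryPoolTrimTo(0);
}
```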
|
||||
|
||||
|
||||
@ -70,9 +70,9 @@ public:
|
||||
TensorPtr finished; // [batchSize, beamWidth], mandatory in beam search and to determine whether to stop
|
||||
// according to DecodingInput.sequenceLimitLength, on gpu
|
||||
TensorPtr finishedSum; // [1], the sum of finished sequences, in pinned memory
|
||||
TensorPtr logProbs; // [maxNewTokens, batchSize, beamWidth], must be float*, on gpu
|
||||
|
||||
// mandatory parameters for beam search
|
||||
TensorPtr logProbs; // [batchSize, beamWidth, maxSeqLen], must be float*, on gpu
|
||||
TensorPtr cumLogProbs; // [batchSize, beamWidth], optional for sampling, on gpu
|
||||
TensorPtr parentIds; // [batchSize, beamWidth, maxSeqLen], on gpu
|
||||
TensorPtr lengths; // [batchSize, beamWidth], total sequence lengths including padding, on gpu
|
||||
|
||||
@ -46,8 +46,10 @@ public:
|
||||
TensorPtr lengths; // [batchSize, beamWidth]
|
||||
|
||||
// optional parameters
|
||||
TensorPtr logProbs; // [request_output_length, batch_size * beam_width], must be float*, on gpu
|
||||
TensorPtr contextLogits; // [batch_size, max_input_length, vocab_size_padded]
|
||||
TensorPtr cumLogProbs; // [batchSize, beamWidth], must be float*, on gpu
|
||||
TensorPtr logProbs; // [batchSize, beamWidth, maxInputLength + maxNewTokens], must be float*, on gpu
|
||||
TensorPtr contextLogits; // [batch_size, max_input_length, vocab_size_padded]
|
||||
TensorPtr generationLogits; // [batch_size, beam_width, max_output_length-1, vocab_size_padded]
|
||||
|
||||
// callbacks
|
||||
Callback onTokenGenerated;
|
||||
|
||||
@ -45,14 +45,15 @@ class IGptDecoder
|
||||
public:
|
||||
virtual ~IGptDecoder() = default;
|
||||
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize) = 0;
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) = 0;
|
||||
|
||||
virtual bool forward(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
static void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
|
||||
DecodingInput const& decodingInput, BufferManager const& manager);
|
||||
virtual void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
|
||||
DecodingInput const& decodingInput, BufferManager const& manager)
|
||||
= 0;
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(
|
||||
nvinfer1::DataType dtype, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
|
||||
@ -64,19 +65,27 @@ class GptDecoder : public virtual IGptDecoder
|
||||
|
||||
public:
|
||||
using CudaStreamPtr = BufferManager::CudaStreamPtr;
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
GptDecoder(size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream);
|
||||
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize) override;
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) override;
|
||||
|
||||
bool forward(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
|
||||
BufferManager const& manager) override;
|
||||
|
||||
private:
|
||||
BufferManager mManager;
|
||||
|
||||
common::CudaAllocator mAllocator;
|
||||
std::shared_ptr<tensorrt_llm::layers::DynamicDecodeLayer<T>> mDynamicDecodeLayer;
|
||||
|
||||
TensorPtr mLogProbsTiled; // Buffer used to store the transpose of the logProbs. Needed because the kernels have
|
||||
// been written to use that shape.
|
||||
};
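The interface change here is the extra `maxSequenceLength` argument to `setup`. Purely as a hedged sketch of a caller adapting to it (the namespace, header availability, and parameter values are assumptions, not taken from the diff; the factory signature is the `IGptDecoder::create` declared above):

```cpp
// Sketch: build a decoder via the IGptDecoder factory and size it for a batch.
// Assumes gptDecoder.h (and its transitive headers) is included.
std::unique_ptr<tensorrt_llm::runtime::IGptDecoder> makeDecoder(nvinfer1::DataType dtype, size_t vocabSize,
    size_t vocabSizePadded, tensorrt_llm::runtime::BufferManager::CudaStreamPtr const& stream,
    tensorrt_llm::runtime::SamplingConfig const& samplingConfig, size_t batchSize,
    tensorrt_llm::runtime::SizeType maxSequenceLength)
{
    auto decoder = tensorrt_llm::runtime::IGptDecoder::create(dtype, vocabSize, vocabSizePadded, stream);
    // Previously: decoder->setup(samplingConfig, batchSize);
    // The new maxSequenceLength argument gives setup the sequence budget up front.
    decoder->setup(samplingConfig, batchSize, maxSequenceLength);
    return decoder;
}
```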
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/gptDecoder.h"
|
||||
#include "tensorrt_llm/runtime/iGptDecoderBatch.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
@ -51,7 +52,8 @@ public:
|
||||
void newRequest(
|
||||
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig) override;
|
||||
|
||||
void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) override;
|
||||
void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
|
||||
|
||||
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
|
||||
@ -85,14 +87,10 @@ public:
|
||||
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
//! Result will only be available after event returned.
|
||||
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
|
||||
//! padding for request `batchIdx`, on gpu
|
||||
[[nodiscard]] std::tuple<CudaEvent, TensorPtr> getFinalOutputIds(SizeType batchIdx) const override;
|
||||
[[nodiscard]] CudaEvent finalize(SizeType batchIdx) const;
|
||||
|
||||
//! @brief Gather final beam search results for all requests.
|
||||
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
|
||||
//! ids without padding, on gpu
|
||||
[[nodiscard]] TensorPtr getFinalOutputIds() const override;
|
||||
void finalize() const override;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains parent ids collected during beam
|
||||
//! search without padding, on gpu
|
||||
@ -119,6 +117,28 @@ public:
|
||||
return ITensor::slice(mJointDecodingOutput->cumLogProbs, 0, mActualBatchSize);
|
||||
}
|
||||
|
||||
//! @returns [maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getCumLogProbs(SizeType batchIdx) const
|
||||
{
|
||||
auto tensor = ITensor::slice(mJointDecodingOutput->cumLogProbs, batchIdx, 1);
|
||||
tensor->squeeze(0);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getLogProbs() const override
|
||||
{
|
||||
return ITensor::slice(mJointDecodingOutput->logProbs, 0, mActualBatchSize);
|
||||
}
|
||||
|
||||
//! @returns [maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getLogProbs(SizeType batchIdx) const
|
||||
{
|
||||
auto tensor = ITensor::slice(mJointDecodingOutput->logProbs, batchIdx, 1);
|
||||
tensor->squeeze(0);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
[[nodiscard]] TensorPtr getNewTokens() const override
|
||||
{
|
||||
|
||||
@ -50,6 +50,7 @@ public:
|
||||
, mMaxOutputLen(0)
|
||||
, mMaxNumTokens(std::nullopt)
|
||||
, mComputeContextLogits(false)
|
||||
, mComputeGenerationLogits(false)
|
||||
, mModelVariant(ModelVariant::kGpt)
|
||||
, mUseCustomAllReduce(false)
|
||||
, mMaxPromptEmbeddingTableSize(0)
|
||||
@@ -222,6 +223,16 @@ public:
        mComputeContextLogits = computeContextLogits;
    }

    [[nodiscard]] bool constexpr computeGenerationLogits() const noexcept
    {
        return mComputeGenerationLogits;
    }

    void constexpr computeGenerationLogits(bool computeGenerationLogits) noexcept
    {
        mComputeGenerationLogits = computeGenerationLogits;
    }
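A hedged illustration (not from the diff) of how the paired `computeGenerationLogits` getter and setter could be used; the function, its parameter name, and the call site are hypothetical:

```cpp
// Sketch: toggle generation-logit output on the model config and branch on it.
void configureLogits(tensorrt_llm::runtime::GptModelConfig& modelConfig, bool gatherGenerationLogits)
{
    modelConfig.computeGenerationLogits(gatherGenerationLogits); // setter overload added here
    if (modelConfig.computeGenerationLogits())                   // getter overload
    {
        // Callers such as gptSessionBenchmark allocate a generationLogits output tensor
        // when this is set (see the benchmark changes earlier in this commit).
    }
}
```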
|
||||
|
||||
[[nodiscard]] ModelVariant getModelVariant() const
|
||||
{
|
||||
return mModelVariant;
|
||||
@ -260,6 +271,7 @@ private:
|
||||
std::optional<SizeType> mMaxNumTokens;
|
||||
|
||||
bool mComputeContextLogits;
|
||||
bool mComputeGenerationLogits;
|
||||
ModelVariant mModelVariant;
|
||||
bool mUseCustomAllReduce;
|
||||
|
||||
|
||||
@ -63,6 +63,8 @@ class GptSession
|
||||
{
|
||||
using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;
|
||||
using KvCacheConfig = batch_manager::kv_cache_manager::KvCacheConfig;
|
||||
using TensorPtr = runtime::ITensor::SharedPtr;
|
||||
using TokenGeneratedCallback = std::function<void(SizeType step, bool finished)>;
|
||||
|
||||
public:
|
||||
using LoggerPtr = std::shared_ptr<nvinfer1::ILogger>;
|
||||
@ -108,7 +110,7 @@ public:
|
||||
|
||||
[[nodiscard]] nvinfer1::ILogger& getLogger() const;
|
||||
|
||||
[[nodiscard]] BufferManager& getBufferManager() const;
|
||||
[[nodiscard]] BufferManager const& getBufferManager() const;
|
||||
|
||||
[[nodiscard]] GptModelConfig const& getModelConfig() const
|
||||
{
|
||||
@ -133,8 +135,9 @@ private:
|
||||
return !mCudaGraphInstances.empty();
|
||||
}
|
||||
|
||||
void generateBatched(GenerationOutput& outputs, std::vector<GenerationInput> const& microBatches,
|
||||
SamplingConfig const& samplingConfig);
|
||||
void generateBatched(std::vector<GenerationOutput>& microBatchesOutputs,
|
||||
std::vector<GenerationInput> const& microBatchesInputs, SamplingConfig const& samplingConfig,
|
||||
TokenGeneratedCallback const& onTokenGenerated);
|
||||
|
||||
void setup(Config const& sessionConfig);
|
||||
|
||||
@ -148,9 +151,9 @@ private:
|
||||
|
||||
void executeContextStep(std::vector<GenerationInput> const& microBatches,
|
||||
std::vector<SizeType> const& microBatchOffsets, KvCacheManager const* kvCacheManager);
|
||||
SizeType executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatches,
|
||||
std::vector<SizeType> const& microBatchOffsets, KvCacheManager* kvCacheManager,
|
||||
std::vector<bool>& microBatchesFinished);
|
||||
SizeType executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatchesInputs,
|
||||
std::vector<GenerationOutput>& microBatchesOutputs, std::vector<SizeType> const& microBatchOffsets,
|
||||
KvCacheManager* kvCacheManager, std::vector<bool>& microBatchesFinished);
|
||||
|
||||
//! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
|
||||
void decoderStepAsync(SizeType decoderStep, SizeType microBatchId);
|
||||
@ -158,17 +161,17 @@ private:
|
||||
//! @brief Synchronize with the decoder and return the `shouldStop` flag.
|
||||
bool shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId);
|
||||
|
||||
//! @brief Collect final output ids on last PP rank and send them to first PP rank.
|
||||
//! @brief Collect final output ids and log probs on last PP rank and send them to first PP rank.
|
||||
//! @details Receives are asynchronous on host, so synchronization is required before access.
|
||||
void finalizeOutputIds(SizeType microBatchId);
|
||||
void finalize(SizeType microBatchId);
|
||||
|
||||
void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, SizeType firstBatchIdx);
|
||||
|
||||
//! @brief Populate outputIds and return reference to newTokens tensor
|
||||
ITensor::SharedPtr initDecoder(ITensor& outputIds, GenerationInput const& inputs,
|
||||
ITensor::SharedPtr initDecoder(ITensor& outputIds, GenerationInput const& inputs, GenerationOutput const& outputs,
|
||||
SamplingConfig const& samplingConfig, SizeType microBatchId) const;
|
||||
|
||||
std::function<void(SizeType step, bool finished)> createOnTokenGeneratedCallback(GenerationOutput& outputs);
|
||||
TokenGeneratedCallback createOnTokenGeneratedCallback(GenerationOutput& outputs);
|
||||
|
||||
class CudaGraphExecutor
|
||||
{
|
||||
|
||||
@ -68,84 +68,108 @@ struct MemoryTypeString<MemoryType::kPINNED>
|
||||
|
||||
//! \brief For converting a TensorRT data type to a C++ data type.
|
||||
template <nvinfer1::DataType kDataType, bool kIsUnsigned = false, bool kIsPointer = false>
|
||||
struct CppDataType
|
||||
struct DataTypeTraits
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kFLOAT>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kFLOAT>
|
||||
{
|
||||
using type = float;
|
||||
static char constexpr name[] = "float";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kHALF>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kHALF>
|
||||
{
|
||||
using type = half;
|
||||
static char constexpr name[] = "half";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kINT8>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kINT8>
|
||||
{
|
||||
using type = std::int8_t;
|
||||
static char constexpr name[] = "int8";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kINT32>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kINT32>
|
||||
{
|
||||
using type = std::int32_t;
|
||||
static char constexpr name[] = "int32";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kINT64>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kINT64>
|
||||
{
|
||||
using type = std::int64_t;
|
||||
static char constexpr name[] = "int64";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kINT32, true>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kINT32, true>
|
||||
{
|
||||
using type = std::uint32_t;
|
||||
static char constexpr name[] = "uint32";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kINT64, true>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kINT64, true>
|
||||
{
|
||||
using type = std::uint64_t;
|
||||
static char constexpr name[] = "uint64";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <bool kUnsigned>
|
||||
struct CppDataType<nvinfer1::DataType::kBOOL, kUnsigned>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kBOOL, kUnsigned>
|
||||
{
|
||||
using type = bool;
|
||||
static char constexpr name[] = "bool";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
template <bool kUnsigned>
|
||||
struct CppDataType<nvinfer1::DataType::kUINT8, kUnsigned>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kUINT8, kUnsigned>
|
||||
{
|
||||
using type = std::uint8_t;
|
||||
static char constexpr name[] = "uint8";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kBF16>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kBF16>
|
||||
{
|
||||
using type = __nv_bfloat16;
|
||||
static char constexpr name[] = "bfloat16";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template <>
|
||||
struct CppDataType<nvinfer1::DataType::kFP8>
|
||||
struct DataTypeTraits<nvinfer1::DataType::kFP8>
|
||||
{
|
||||
using type = __nv_fp8_e4m3;
|
||||
static char constexpr name[] = "fp8";
|
||||
static auto constexpr size = sizeof(type);
|
||||
};
|
||||
#endif
|
||||
|
||||
template <nvinfer1::DataType kDataType, bool kUnsigned>
struct CppDataType<kDataType, kUnsigned, true>
struct DataTypeTraits<kDataType, kUnsigned, true>
{
    using type = typename CppDataType<kDataType, kUnsigned, false>::type*;
    using type = typename DataTypeTraits<kDataType, kUnsigned, false>::type*;
    static char constexpr name[] = "*";
    static auto constexpr size = sizeof(type);
};
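For context, the rename from `CppDataType` to `DataTypeTraits` keeps the same compile-time mapping from a TensorRT dtype to a C++ type, a printable name, and a size. The block below is a hedged usage sketch; the `tensorrt_llm::runtime` namespace is assumed from the surrounding runtime headers.

```cpp
// Sketch: query the renamed trait at compile time.
#include <cstdio>

template <nvinfer1::DataType kDType>
void describeDataType()
{
    using Traits = tensorrt_llm::runtime::DataTypeTraits<kDType>;
    static_assert(Traits::size == sizeof(typename Traits::type), "size mirrors sizeof(type)");
    std::printf("%s: %zu bytes\n", Traits::name, Traits::size);
}

// describeDataType<nvinfer1::DataType::kFLOAT>(); // expected to print "float: 4 bytes"
// describeDataType<nvinfer1::DataType::kHALF>();  // expected to print "half: 2 bytes"
```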
|
||||
|
||||
//! \brief A wrapper around `nvinfer1::DataType` that provides a support for pointer types.
|
||||
@ -377,11 +401,15 @@ public:
|
||||
//!
|
||||
[[nodiscard]] virtual DataType getDataType() const = 0;
|
||||
|
||||
virtual char const* getDataTypeName() const;
|
||||
|
||||
//!
|
||||
//! \brief Returns the memory type of the buffer.
|
||||
//!
|
||||
[[nodiscard]] virtual MemoryType getMemoryType() const = 0;
|
||||
|
||||
virtual char const* getMemoryTypeName() const;
|
||||
|
||||
//!
|
||||
//! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity.
|
||||
//!
|
||||
|
||||
@ -39,10 +39,12 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
explicit Request(ConstTensorPtr ids, std::optional<SizeType> maxNewTokens = std::nullopt,
|
||||
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt)
|
||||
std::optional<SizeType> endId = std::nullopt)
|
||||
: ids{std::move(ids)}
|
||||
, maxNewTokens{maxNewTokens}
|
||||
, endId{endId}
|
||||
, computeCumLogProbs(false)
|
||||
, computeLogProbs(false)
|
||||
{
|
||||
}
|
||||
|
||||
@ -55,6 +57,9 @@ public:
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength], on gpu
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
bool computeLogProbs; // boolean that controls if logProbs should be computed for that request
|
||||
};
|
||||
|
||||
class Input : public decoder::Input
|
||||
@ -128,9 +133,7 @@ public:
|
||||
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
//! Result will only be available after event returned
|
||||
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
|
||||
//! padding for request `batchIdx`, on gpu
|
||||
virtual std::tuple<CudaEvent, TensorPtr> getFinalOutputIds(SizeType batchIdx) const = 0;
|
||||
virtual CudaEvent finalize(SizeType batchIdx) const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth], marks finished requests (per beam), on gpu
|
||||
virtual TensorPtr getFinishedBeams() const = 0;
|
||||
@ -144,6 +147,15 @@ public:
|
||||
//! @returns [batchSize, beamWidth], cumulative log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getCumLogProbs() const = 0;
|
||||
|
||||
//! @returns [beamWidth], cumulative log probabilities (per beam) for request batchIdx, on gpu
|
||||
virtual TensorPtr getCumLogProbs(SizeType batchIdx) const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth, maxSeqLen], log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getLogProbs() const = 0;
|
||||
|
||||
//! @returns [beamWidth, maxSeqLen], log probabilities (per beam) for request batchIdx, on gpu
|
||||
virtual TensorPtr getLogProbs(SizeType batchIdx) const = 0;
|
||||
|
||||
virtual TensorPtr getParentIds() const = 0;
|
||||
|
||||
virtual std::vector<SizeType> getNbSteps() const = 0;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
|
||||
@ -78,7 +79,9 @@ public:
|
||||
= 0;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
virtual void newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig) = 0;
|
||||
virtual void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
|
||||
= 0;
|
||||
|
||||
//! @brief Run one step for all requests without blocking the host thread.
|
||||
virtual void forwardAsync(decoder::Output& output, decoder::Input const& input) = 0;
|
||||
@ -94,11 +97,17 @@ public:
|
||||
}
|
||||
|
||||
//! @brief Gather final beam search results for all requests.
|
||||
virtual TensorPtr getFinalOutputIds() const = 0;
|
||||
virtual void finalize() const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth, maxSequenceLength], all token ids, on gpu
|
||||
virtual TensorPtr getOutputIds() const = 0;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getCumLogProbs() const = 0;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
|
||||
virtual TensorPtr getLogProbs() const = 0;
|
||||
|
||||
//! @returns [batchSize, beamWidth], latest generated tokens (per beam), on gpu
|
||||
virtual TensorPtr getNewTokens() const = 0;
|
||||
|
||||
|
||||
@ -38,22 +38,22 @@ add_subdirectory(runtime)
|
||||
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
|
||||
set(BATCH_MANAGER_TARGET_ARCH "unknown")
|
||||
|
||||
execute_process(
|
||||
COMMAND grep -oP "(?<=^ID=).+" /etc/os-release
|
||||
COMMAND tr -d "\""
|
||||
COMMAND tr -d "\n"
|
||||
RESULT_VARIABLE _OS_ID_SUCCESS
|
||||
OUTPUT_VARIABLE OS_ID)
|
||||
execute_process(
|
||||
COMMAND grep -oP "(?<=^VERSION_ID=).+" /etc/os-release
|
||||
COMMAND tr -d "\""
|
||||
COMMAND tr -d "\n"
|
||||
RESULT_VARIABLE _OS_VERSION_ID_SUCCESS
|
||||
OUTPUT_VARIABLE OS_VERSION_ID)
|
||||
message(STATUS "Operating System: ${OS_ID}, ${OS_VERSION_ID}")
|
||||
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if(NOT WIN32) # Linux
|
||||
execute_process(
|
||||
COMMAND grep -oP "(?<=^ID=).+" /etc/os-release
|
||||
COMMAND tr -d "\""
|
||||
COMMAND tr -d "\n"
|
||||
RESULT_VARIABLE _OS_ID_SUCCESS
|
||||
OUTPUT_VARIABLE OS_ID)
|
||||
execute_process(
|
||||
COMMAND grep -oP "(?<=^VERSION_ID=).+" /etc/os-release
|
||||
COMMAND tr -d "\""
|
||||
COMMAND tr -d "\n"
|
||||
RESULT_VARIABLE _OS_VERSION_ID_SUCCESS
|
||||
OUTPUT_VARIABLE OS_VERSION_ID)
|
||||
message(STATUS "Operating System: ${OS_ID}, ${OS_VERSION_ID}")
|
||||
|
||||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
set(BATCH_MANAGER_TARGET_ARCH "x86_64-linux-gnu")
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a3ec9a8760d7b8ace53e420572aeb1b3607effc92fd56e13351fa4cbddbbb37
|
||||
size 1646420
|
||||
oid sha256:b867c2e048671eecc421244d436436782093baf02f0fd5d49232b3d3042e55ea
|
||||
size 1688216
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:114348de9f6d1b3fa147f4fbccede10b7dbe13da6c5c86e968bb56bf05f9ec5a
|
||||
size 1657852
|
||||
oid sha256:db433a13ec6a017638bbb97b53a98624ad675b395787c99054d48ab370f5e3a0
|
||||
size 1697778
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
0776a4d41c06192c4ca0409ad8b837de libtensorrt_llm_batch_manager_static.a
|
||||
c901725d5d278fd8d41f524f81fe5170 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
b3330c65d9b23d4f20c2b8d5a7c24cd45c910cd4 commit
|
||||
81f472ac2b68edd03a0265299744347f libtensorrt_llm_batch_manager_static.a
|
||||
4e5e3bbdfffa6deb6a50c541a946ac7a libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7edd8a21 commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:abdce9bc64cecddb39ed14809eefc8bcf7164524a6dd20ec7c8167229f3c22a3
|
||||
size 1557782
|
||||
oid sha256:681917aea11f45d83ba1429ded44ced97cb8ce5f54eb1c3fb3055bc342f0ffbf
|
||||
size 1600734
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a9109b506e993a041ea238f992bec2a5064dffd9c0a7af10cca0d4d96c5047a9
|
||||
size 1557482
|
||||
oid sha256:d59b04e3229358ec2d9476b07f0361aa4a8539e543312c8952b690173040663d
|
||||
size 1598666
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
25d1ebdd5977208c25023329c621e970 libtensorrt_llm_batch_manager_static.a
|
||||
5cb1a7a13db34fcaee6b89fcdc1212ce libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
c9d5678a2ec347188457ad4a3a59d483 libtensorrt_llm_batch_manager_static.a
|
||||
53261d576d540ab330f2f2e1f8d99677 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
@ -24,16 +26,7 @@
|
||||
#include <vector>
|
||||
|
||||
#define COMM_WORLD MpiComm(MPI_COMM_WORLD)
|
||||
#define MPICHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) \
|
||||
{ \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
#define MPICHECK(cmd) TLLM_MPI_CHECK(cmd)
|
||||
|
||||
// A wrapper module of the MPI library.
|
||||
namespace tensorrt_llm::mpi
|
||||
|
||||
@ -92,31 +92,18 @@ namespace threadblock
|
||||
namespace detail
|
||||
{
|
||||
|
||||
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::half_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape, ThreadMap>
|
||||
{
|
||||
|
||||
using WarpTileIterator
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
|
||||
|
||||
static int const kFragmentsPerIteration = 1;
|
||||
};
|
||||
|
||||
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
|
||||
ThreadMap>
|
||||
{
|
||||
|
||||
using WarpTileIterator
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
|
||||
using SharedLoadIterator
|
||||
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
static int const kFragmentsPerIteration = 1;
|
||||
static int const kFragmentsPerIteration = 2;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -0,0 +1,438 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
    \file
    \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
        batched array variants.
*/

#pragma once

// #include <limits>

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
    This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
    It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
    and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

    Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
    that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;

    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;

    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;

    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
    {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

        gemm_k_size = args.problem_size.k();

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            int const kAlignK
                = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size)
            {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }

public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
        {
            return Status::kErrorInvalidProblem;
        }

        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

        size_t workspace_bytes = 0;

        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        }
        else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
        {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
            << "  result = {" << result << "}");

        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

        if (smem_size <= (48 << 10))
        {
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

            if (result == cudaSuccess)
            {
                CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        }
        else
        {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(
                    "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));

                return -1;
            }

            if (smem_capacity < 0)
            {
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }

            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

            CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

            return occupancy;
        }

        CUTLASS_TRACE_HOST("  returning internal error");

        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        size_t workspace_bytes = get_workspace_size(args);

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        if (workspace_bytes)
        {
            if (!workspace)
            {
                CUTLASS_TRACE_HOST("  error: device workspace must not be null");

                return Status::kErrorWorkspaceNull;
            }

            if (args.mode == GemmUniversalMode::kGemm)
            {
                CUTLASS_TRACE_HOST("  clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

                if (result != cudaSuccess)
                {
                    CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));

                    return Status::kErrorInternal;
                }
            }
        }

        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        params_.update(args, workspace);

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

        //
        // Configure grid and block dimensions
        //

        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        //
        // Launch kernel
        //

        CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

        //
        // Query for errors
        //
        cudaError_t result = cudaGetLastError();

        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
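For orientation, here is a minimal sketch of how this compatibility device layer is typically driven. The kernel type `MyGemmKernel` is a placeholder for whatever kernel-level template the extensions instantiate; only the call pattern (`can_implement` / `get_workspace_size` / `initialize` / `run`) is taken from the class above.

```cpp
// Sketch only: MyGemmKernel and the Arguments contents are assumptions; they depend on how the
// kernel-level GEMM was instantiated elsewhere in cutlass_extensions.
using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>;

cutlass::Status run_gemm(typename Gemm::Arguments const& args, cudaStream_t stream)
{
    Gemm gemm_op;

    // Reject problems the kernel cannot handle (e.g. grid.y/z exceeding 65535).
    cutlass::Status status = Gemm::can_implement(args);
    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }

    // Ask the device layer how much scratch it needs (split-K reductions, semaphores, ...).
    size_t workspace_bytes = Gemm::get_workspace_size(args);
    void* workspace = nullptr;
    if (workspace_bytes != 0 && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess)
    {
        return cutlass::Status::kErrorInternal;
    }

    // initialize() builds GemmKernel::Params and opts into >48KB dynamic shared memory if needed;
    // run() launches cutlass::Kernel<GemmKernel> on the given stream.
    status = gemm_op.initialize(args, workspace, stream);
    if (status == cutlass::Status::kSuccess)
    {
        status = gemm_op.run(stream);
    }

    cudaFree(workspace);
    return status;
}
```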
@@ -18,6 +18,12 @@
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)

# skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
if(FAST_BUILD)
  list(FILTER SRC_CU EXCLUDE REGEX
       "decoderMaskedMultiheadAttention(48|80|96|112|144|160|192|224).*cu$")
endif()

add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU})
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

@@ -14,6 +14,7 @@
 * limitations under the License.
 */

#include <algorithm> // all_of
#include <assert.h>

#include "tensorrt_llm/common/assert.h"

@@ -19,9 +19,9 @@
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32

#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/gemm/device/gemm_universal_base_compat.h"

#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
@@ -124,7 +124,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
        return;
    }

    using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
    using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

    const int ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
        ? n

@@ -22,12 +22,13 @@
// clang-format off
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_base.h>
#include <cutlass_extensions/gemm/device/gemm_universal_base_compat.h>
#include <cutlass/gemm/kernel/default_gemm.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
// clang-format on

#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm_configs.h"
@@ -123,7 +124,7 @@ void genericInt8GemmKernelLauncher(const int8_t* A, const int8_t* B, tk::QuantMo
        return;
    }

    using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
    using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

    typename EpilogueOp::Params linearScalingParams; // TODO: right now it's unused (scaling is done in
                                                     // visitor, no activation needed)

@@ -226,7 +226,7 @@ void mmha_launch_kernel_ex(
    }

    // If blocks with larger block size already fill all SMs, then disable the multi blocks mode.
    mmha::multi_block_grid_setup<T, Dh>(grid, params, dynamic_block_size, available_blocks, tlength, DO_MULTI_BLOCK);
    mmha::multi_block_grid_setup<T, Dh>(grid, params, available_blocks, dynamic_block_size, tlength, DO_MULTI_BLOCK);

    // Launch kernels based on the valid block size.
    switch (dynamic_block_size)
@@ -1411,7 +1411,8 @@ __global__ void masked_multihead_attention_kernel(
    bool has_relative_attention_bias = params.relative_attention_bias != nullptr;
    // Compute relative attention bias on the fly, with relative attention table [head_num/TP, num_buckets] passed in.
    // num_buckets passed as relative_attention_bias_stride, max_distance passed as params.max_distance
    const bool implicit_rel_attn_bias = DO_CROSS_ATTENTION && params.max_distance != 0 && has_relative_attention_bias;
    // this is a common optimization for both self attention and cross attention
    const bool implicit_rel_attn_bias = params.max_distance != 0 && has_relative_attention_bias;
    int relative_attention_bias_stride
        = params.relative_attention_bias_stride; // num_buckets might be modified below, save it beforehand
    int max_distance = params.max_distance;
@@ -1693,12 +1694,15 @@ __global__ void masked_multihead_attention_kernel(

    // Pre-compute the pointer for the relative attention bias.
    const T* relative_attention_bias_ptr = nullptr;
    const T* relative_attention_bias_ptr_fixed = nullptr; // record the base for offset
    if (has_relative_attention_bias)
    {
        // "hi" is unsigned, subtracting int from unsigned int causes underflow. Cast to int
        int64_t offset = implicit_rel_attn_bias
            ? (hi * relative_attention_bias_stride - tlength)
            : (hi * relative_attention_bias_stride + tlength) * relative_attention_bias_stride;
            ? ((int64_t) hi * relative_attention_bias_stride - tlength)
            : ((int64_t) hi * relative_attention_bias_stride + tlength) * relative_attention_bias_stride;
        relative_attention_bias_ptr = &params.relative_attention_bias[offset];
        relative_attention_bias_ptr_fixed = &params.relative_attention_bias[offset];
    }

    // Load the value.
@@ -1706,7 +1710,7 @@ __global__ void masked_multihead_attention_kernel(
    if (has_relative_attention_bias && tidx == 0)
    {
        // TODO: Use a better way to convert from T to float.
        add(relative_attention_bias, relative_attention_bias_ptr[tlength]);
        relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[tlength]);
    }

    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
@@ -1769,7 +1773,19 @@ __global__ void masked_multihead_attention_kernel(

    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
    // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
    const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length;
    const int context_length
        = DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
    // Clarifications:
    // - in self attn, input_length is input text length, tlength is current timestep
    // - in cross attn, input_length is *decoder* input length (usually 1), tlength is *encoder* input context length
    // - in beam search, since the cache during generation is organized differently, the following KV compute needs
    //   split into context cache compute and generation cache compute
    // - for self attn, no-beam search: entire cache can be treated as context cache --> context_length = tlength
    // - for self attn, beam search: cache of input text length is context cache, other are generation cache -->
    //   context_length = input_length
    // - for cross attn, no-beam/beam search: cache length is fixed, not differ context/generation cache -->
    //   context_length = tlength Suggestion: we could have a flag HANDLE_GEN_CACHE

    const auto context_ti_end = MULTI_BLOCK_FLAG
        ? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP
        : divUp(static_cast<unsigned>(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP;
@@ -1872,7 +1888,7 @@ __global__ void masked_multihead_attention_kernel(
                relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
                relative_buckets += is_small ? relative_position : relative_position_if_large;
                relative_attention_bias_ptr
                    = relative_attention_bias_ptr + (tlength - local_time_now) + relative_buckets;
                    = relative_attention_bias_ptr_fixed + (tlength - local_time_now) + relative_buckets;
            }

            // Prefetch the relative attention bias.
@@ -1880,7 +1896,7 @@ __global__ void masked_multihead_attention_kernel(
            if (is_active && has_relative_attention_bias)
            {
                // TODO: Use a better way to convert from T to float.
                add(relative_attention_bias, relative_attention_bias_ptr[local_time_now]);
                relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[local_time_now]);
            }

            // Compute the dot product between Q and K.
@@ -1937,7 +1953,9 @@ __global__ void masked_multihead_attention_kernel(

    // Handle generation key cache with beam searching.
    // Note that it may be overlapped with the context key loop, but it won't impact the corretness.
    if (HAS_BEAMS && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length))
    // Can skip in cross attention mode.
    if (HAS_BEAMS && !DO_CROSS_ATTENTION
        && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length))
    {
        // The input length;
        const int input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length;
@@ -1987,7 +2005,8 @@ __global__ void masked_multihead_attention_kernel(
                        * (num_buckets - max_exact));
                relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
                relative_buckets += is_small ? relative_position : relative_position_if_large;
                relative_attention_bias_ptr = relative_attention_bias_ptr + (tlength - time_now) + relative_buckets;
                relative_attention_bias_ptr
                    = relative_attention_bias_ptr_fixed + (tlength - time_now) + relative_buckets;
            }

            // Prefetch the relative attention bias.
@@ -1995,7 +2014,7 @@ __global__ void masked_multihead_attention_kernel(
            if (is_active && has_relative_attention_bias)
            {
                // TODO: Use a better way to convert from T to float.
                add(relative_attention_bias, relative_attention_bias_ptr[time_now]);
                relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[time_now]);
            }

            // Perform the dot product and normalize qk.
@@ -2260,7 +2279,8 @@ __global__ void masked_multihead_attention_kernel(
    // Handle both context and generation value cache without beam searching.
    // Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables.
    // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
    const int context_length = HAS_BEAMS ? beam0_context_length : kv_loop_length;
    const int context_length
        = DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
    int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length;
    int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
    for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER)
@@ -2300,7 +2320,7 @@ __global__ void masked_multihead_attention_kernel(
    }

    // Handle generation value cache with beam searching.
    if (HAS_BEAMS)
    if (HAS_BEAMS && !DO_CROSS_ATTENTION)
    {
        const auto generation_start_ti
            = MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER);

@@ -2226,6 +2226,8 @@ inline __device__ Float8_ mul(uint4 a, int64_t b)
    fc.y = mul<float2, uint32_t, float2>(a.y, make_float2(int8[2], int8[3]));
    fc.z = mul<float2, uint32_t, float2>(a.z, make_float2(int8[4], int8[5]));
    fc.w = mul<float2, uint32_t, float2>(a.w, make_float2(int8[6], int8[7]));

    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -2247,6 +2249,8 @@ inline __device__ Float8_ mul(Float8_ fa, int64_t b)
    fc.y = mul<float2, float2, float2>(fa.y, make_float2(int8[2], int8[3]));
    fc.z = mul<float2, float2, float2>(fa.z, make_float2(int8[4], int8[5]));
    fc.w = mul<float2, float2, float2>(fa.w, make_float2(int8[6], int8[7]));

    return fc;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -2323,6 +2327,8 @@ inline __device__ Float8_ mul(bf16_8_t a, int64_t b)
    fc.y = mul<float2, __nv_bfloat162, float2>(a.y, make_float2(int8[2], int8[3]));
    fc.z = mul<float2, __nv_bfloat162, float2>(a.z, make_float2(int8[4], int8[5]));
    fc.w = mul<float2, __nv_bfloat162, float2>(a.w, make_float2(int8[6], int8[7]));

    return fc;
}

#endif // ENABLE_BF16
@@ -416,8 +416,13 @@ __global__ void finalize(int* output_ids, int* sequence_lengths, float* cum_log_
            = topk_output_ids[blockIdx.x * (beam_width * 2) * max_seq_len + s_rank[beam_idx] * max_seq_len + i];
        if (output_log_probs != nullptr)
        {
            output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i]
                = topk_log_probs[blockIdx.x * (beam_width * 2) * max_seq_len + s_rank[beam_idx] * max_seq_len + i];
            int input_len = input_lengths[blockIdx.x * beam_width + beam_idx];
            if (i >= input_len)
            {
                output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i - input_len]
                    = topk_log_probs[blockIdx.x * (beam_width * 2) * max_seq_len + s_rank[beam_idx] * max_seq_len
                        + i];
            }
        }
    }
}
@@ -471,5 +476,32 @@ void invokeCopyNextStepIds(int* next_step_ids, int** output_ids_ptr, const int*
        next_step_ids, output_ids_ptr, sequence_lengths, batch_size, beam_width, max_seq_len);
}

__global__ void transposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
    int batch_size, int beam_width, int max_seq_len)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;

    const int batch_idx = index / (beam_width * max_seq_len);
    const int tmp_idx = index % (beam_width * max_seq_len);
    const int beam_idx = tmp_idx / max_seq_len;
    const int pos = tmp_idx % max_seq_len;

    if (batch_idx < batch_size && pos < sequence_lengths[batch_idx])
    {
        output_log_probs[index]
            = output_log_probs_tiled[pos * batch_size * beam_width + batch_idx * beam_width + beam_idx];
    }
}

void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
    int batch_size, int beam_width, int max_seq_len, cudaStream_t stream)
{
    dim3 block(256);
    dim3 grid(divUp(batch_size * beam_width * max_seq_len, block.x));
    transposeLogProbs<<<grid, block, 0, stream>>>(
        output_log_probs, output_log_probs_tiled, sequence_lengths, batch_size, beam_width, max_seq_len);
}

} // namespace kernels
} // namespace tensorrt_llm

@@ -62,5 +62,8 @@ void invokeInitializeOutput(int* output_ids, const int* end_ids, int batch_beam,
void invokeCopyNextStepIds(int* next_step_ids, int** output_ids_ptr, const int* sequence_lengths, int batch_size,
    int beam_width, int max_seq_len, cudaStream_t stream);

void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
    int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
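As an illustration of what the new `transposeLogProbs` kernel computes, the sketch below is a host-side reference of the same index mapping: the tiled buffer is laid out as [max_seq_len, batch_size, beam_width] and the transposed output as [batch_size, beam_width, max_seq_len], with positions at or beyond a sequence's length skipped, exactly like the bounds check in the kernel.

```cpp
// Host-side reference for the layout change, not part of the commit; the argument names mirror
// the kernel's parameters above.
void transposeLogProbsRef(float* output_log_probs, const float* output_log_probs_tiled,
    const int* sequence_lengths, int batch_size, int beam_width, int max_seq_len)
{
    for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx)
    {
        for (int beam_idx = 0; beam_idx < beam_width; ++beam_idx)
        {
            for (int pos = 0; pos < sequence_lengths[batch_idx]; ++pos)
            {
                // Destination: [batch_size, beam_width, max_seq_len]
                const int dst = (batch_idx * beam_width + beam_idx) * max_seq_len + pos;
                // Source: [max_seq_len, batch_size, beam_width]
                const int src = (pos * batch_size + batch_idx) * beam_width + beam_idx;
                output_log_probs[dst] = output_log_probs_tiled[src];
            }
        }
    }
}
```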
@@ -195,12 +195,23 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
        isValid = (rowIdx < seqLength - 1 && colIdx < seqLength - 1) ||
                  (rowIdx == seqLength - 1 && colIdx < seqLength);
        // clang-format on
        // seq_length==4, max_seq_len==5, only use in context phase
        // seq_length==4, max_seq_len==5
        // 1 1 1 0 0
        // 1 1 1 0 0
        // 1 1 1 0 0
        // 1 1 1 1 0
        // 0 0 0 0 0
    case AttentionMaskType::BIDIRECTIONALGLM:
        // clang-format off
        isValid = (colIdx < seqLength - 1) ||
                  (rowIdx == maxSeqLength - 1 && colIdx == maxSeqLength - 1);
        // clang-format on
        // seq_length==4, max_seq_len==5
        // 1 1 1 1 0
        // 1 1 1 1 0
        // 1 1 1 1 0
        // 1 1 1 1 0
        // 1 1 1 1 1
        break;
    }

@@ -31,7 +31,10 @@ enum class AttentionMaskType
    // Mask the padded tokens and all the tokens that come after in a sequence.
    CAUSAL = 1,
    // See ChatGLM-6B mask.
    BIDIRECTIONAL = 2
    BIDIRECTIONAL = 2,
    // See GLM-10B mask.
    // TODO: merge this mask into BIDIRECTIONAL
    BIDIRECTIONALGLM = 3
};

enum class PositionEmbeddingType : int8_t
@@ -58,7 +61,7 @@ struct BuildDecoderInfoParams
{
    // The offsets to the 1st token in each sequence. Shape: [batchSize+1].
    int* seqOffsets;
    // The number of padded tokens in the corresponding padded tensor. Shape: [numTokens].
    // The number of padded tokens in the corresponding padded tensor before the current token. Shape: [numTokens].
    int* paddingOffsets;

    // The mask to mark invalid tokens in Attention - that's not used by the plugins as it can be
@@ -18,6 +18,7 @@
#include "tensorrt_llm/common/assert.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <limits>

namespace tensorrt_llm
{
@@ -64,6 +65,11 @@ struct KVBlockArray
        const float tokensPerBlockSeqLog2 = log2(mTokensPerBlock);
        TLLM_CHECK_WITH_INFO(
            ceil(tokensPerBlockSeqLog2) == floor(tokensPerBlockSeqLog2), "tokensPerBlock must be power of 2");
        // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr).
        // If needed, we could do it on uint32_t or even uint64_t, but that might have performance implications
        TLLM_CHECK_WITH_INFO(static_cast<int64_t>(mMaxSeqs - 1) * mMaxBlocksPerSeq * 2 + maxBlocksPerSeq
                <= std::numeric_limits<int32_t>::max(),
            "kv cache is too large for gpt_attention_plugin");
        mTokensPerBlockLog2 = static_cast<int>(tokensPerBlockSeqLog2);
    }

@@ -140,6 +146,11 @@ struct KVLinearBuffer
        , mMaxSeqLen(tokensPerBlock)
        , mBytesPerSeq(tokensPerBlock * sizePerToken)
    {
        // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr).
        // If needed, we could do it on uint32_t or even uint64_t, but that might have performance implications
        TLLM_CHECK_WITH_INFO(
            static_cast<int64_t>(mMaxSeqs - 1) * mBytesPerSeq * 2 + mBytesPerSeq <= std::numeric_limits<int32_t>::max(),
            "kv cache is too large for gpt_attention_plugin");
    }

    __host__ __device__ inline void** getRowPtr(KVIdxType kvIdx, int32_t seqIdx)
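The two checks added above guard the same thing: row-pointer indexing in these structures is done in `int32_t`, so the largest index a sequence can produce (K and V planes interleaved per sequence) must stay below `INT32_MAX`. A minimal host-side sketch of the `KVBlockArray` variant of that bound, for illustration only:

```cpp
// Illustrative helper, not part of the commit: mirrors the TLLM_CHECK_WITH_INFO above.
#include <cstdint>
#include <limits>

bool kvBlockArrayFitsInt32(int32_t maxSeqs, int32_t maxBlocksPerSeq)
{
    // Largest block index getRowPtr-style arithmetic can produce: the last sequence's last
    // block, with K and V blocks interleaved (factor of 2), plus one more plane of blocks.
    const int64_t maxIndex
        = static_cast<int64_t>(maxSeqs - 1) * maxBlocksPerSeq * 2 + maxBlocksPerSeq;
    return maxIndex <= std::numeric_limits<int32_t>::max();
}
```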
@@ -198,12 +198,10 @@ void dispatch_layernorm_type_square_method(const T* input, const T* gamma, const
    float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block,
    const size_t shmem_size, cudaStream_t stream)
{
    bool use_shmem = true;
    if (shmem_size >= (48 << 10))
    {
        cudaError_t ret = cudaFuncSetAttribute(
            generalLayerNorm<T, USE_DIFF_OF_SQUARES>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
        use_shmem = ret == cudaSuccess;
    }
    generalLayerNorm<T, USE_DIFF_OF_SQUARES><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output,
        eps, tokens, hidden_dim, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, true);

@@ -155,12 +155,10 @@ void dispatch_rmsnorm_type_square_method(const T* input, const T* gamma, const T
    float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block,
    const size_t shmem_size, cudaStream_t stream)
{
    bool use_shmem = true;
    if (shmem_size >= (48 << 10))
    {
        cudaError_t ret
            = cudaFuncSetAttribute(generalRmsNorm<T>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
        use_shmem = ret == cudaSuccess;
    }
    generalRmsNorm<T><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output, eps, tokens, hidden_dim,
        scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, true);

@@ -203,7 +203,6 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
     * output.
     */

    __shared__ int stopShared;
    __shared__ float randNumS;

    const int tid = threadIdx.x;
@@ -233,7 +232,6 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
    // will choose the token which probability makes cumulative probability sum to exceed P'
    if (threadIdx.x == 0)
    {
        stopShared = 0;
        randNumS = curand_uniform(curandstate + blockIdx.x) * probThreshold;
    }

@@ -1361,7 +1361,6 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,

    switch (position_embedding_type)
    {
    case PositionEmbeddingType::kRELATIVE:
    case PositionEmbeddingType::kROPE_GPTJ:
    {
        mmha::apply_rotary_embedding(
@@ -18,6 +18,9 @@
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/beamSearchPenaltyKernels.h"
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
#include "tensorrt_llm/layers/fillBuffers.h"

#include <algorithm>

using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
@@ -32,7 +35,7 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i
    int beam_width, int max_kv_cache_length, int max_seq_len)
{
    int time_step = threadIdx.x + blockIdx.x * blockDim.x;
    int bb_id = threadIdx.y + blockIdx.y * blockDim.y;
    int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y?
    const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1
    const int batch_id = bb_id / beam_width;
    const int beam_id = bb_id % beam_width;
@@ -46,9 +49,9 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i
    const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step];

    // for the indir tables, we have the cyclic kv cache.
    const uint tgt_offset
    const uint32_t tgt_offset
        = batch_id * beam_width * max_kv_cache_length + beam_id * max_kv_cache_length + time_step_circ;
    const uint src_offset
    const uint32_t src_offset
        = batch_id * beam_width * max_kv_cache_length + src_beam * max_kv_cache_length + time_step_circ;

    tgt_indir_cache[tgt_offset] = (time_step == current_step) ? beam_id : src_indir_cache[src_offset];
@@ -113,7 +116,7 @@ void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
    repetition_penalty_buf_ = allocator_->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);

    is_allocate_buffer_ = true;
    TLLM_LOG_DEBUG("% stop", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
@@ -122,25 +125,7 @@ void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& set
    allocateBuffer(batch_size);
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    // Setup penalties.
    auto fillBuffers
        = [this, &batch_size](auto const& optParam, auto const defaultValue, auto& hostBuffer, auto& deviceBuffer)
    {
        hostBuffer.resize(batch_size);
        if (!optParam)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), defaultValue);
        }
        else if (optParam->size() == 1)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), optParam->front());
        }
        else
        {
            TLLM_CHECK_WITH_INFO(optParam->size() == batch_size, "Argument vector size mismatch.");
            std::copy(optParam->begin(), optParam->end(), std::begin(hostBuffer));
        }
        cudaAutoCpy(deviceBuffer, hostBuffer.data(), batch_size, stream_);
    };
    FillBuffers const fillBuffers{batch_size, stream_};

    fillBuffers(setupParams.temperature, 1.0f, mTemperature, temperature_buf_);
    fillBuffers(setupParams.min_length, 1, mMinLength, min_lengths_buf_);

@@ -321,7 +321,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
                = outputs.cum_log_probs->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);

            dynamic_decode_outputs.beamHypotheses = outputs.beamHypotheses;
            dynamic_decode_outputs.output_log_probs = outputs.output_log_probs;
            dynamic_decode_outputs.output_log_probs = outputs.output_log_probs_tiled;

            // only OnlineBeamSearchLayer support beam_search_diversity_rate
            // when beamHypotheses is used
@@ -365,11 +365,11 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
                decode_outputs.cum_log_probs
                    = outputs.cum_log_probs->slice({local_batch_size * beam_width}, local_batch_offset);
            }
            if (outputs.output_log_probs)
            if (outputs.output_log_probs_tiled)
            {
                auto const generationStep = step - params.max_input_length;
                TLLM_CHECK(generationStep >= 0);
                Tensor& output_log_probs = outputs.output_log_probs.value();
                Tensor& output_log_probs = outputs.output_log_probs_tiled.value();
                size_t step_offset = generationStep * batch_size * beam_width;
                decode_outputs.output_log_probs
                    = output_log_probs.slice({output_log_probs.shape[0] - generationStep, local_batch_size * beam_width},
@@ -411,6 +411,17 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&

    invokeCopyNextStepIds(outputs.newTokens.template getPtr<int>(), idsPtrHost,
        outputs.sequence_length->template getPtr<int>(), batch_size, beam_width, max_seq_len, stream_);

    // Transpose the output log probs from [max_seq_len, bs, beam_width] to [batch_size, beam_width, max_seq_len]
    if (outputs.output_log_probs_tiled)
    {
        auto logProbsMaxSeqLen = outputs.output_log_probs_tiled.value().shape[0];

        invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr<float>(),
            outputs.output_log_probs_tiled.value().template getPtr<float>(),
            outputs.sequence_length->template getPtr<int>(), batch_size, beam_width, logProbsMaxSeqLen, stream_);
    }

    sync_check_cuda_error();
}

@@ -129,14 +129,16 @@
    std::optional<tc::Tensor> parent_ids;      // [max_seq_len, batch_size * beam_width], necessary in beam search
    std::optional<tc::Tensor> sequence_length; // [batch_size * beam_width], optional
    std::optional<tc::Tensor>
        output_log_probs; // [request_ouptut_length, batch_size * beam_width], must be float*, optional
        output_log_probs_tiled; // [request_output_length, batch_size, beam_width], must be float*, optional
    std::optional<tc::Tensor>
        tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search
        output_log_probs; // [batchSize, beam_width, request_ouptut_length], must be float*, optional
    std::optional<tc::Tensor>
        tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search
    std::shared_ptr<kernels::BeamHypotheses>
        beamHypotheses; // a special structure which maintains some pointers of beam search
        beamHypotheses; // a special structure which maintains some pointers of beam search

    tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
    tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
    tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
    tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
};

void forward(OutputParams& outputs, ForwardParams const& params);
68  cpp/tensorrt_llm/layers/fillBuffers.h  Normal file
@@ -0,0 +1,68 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <algorithm>
#include <optional>
#include <utility>
#include <vector>

#include <cuda_runtime.h>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/memoryUtils.h"

namespace tensorrt_llm
{
namespace layers
{

// Using a local lambda in beam search layers to fill buffers causes an internal compiler error on nvcc windows.
// As a workaround and to promote DRY, the fill logic is refactored into FillBuffers below.
struct FillBuffers
{

    template <typename T>
    void operator()(std::optional<std::vector<T>> const& optParam, T const defaultValue, std::vector<T>& hostBuffer,
        T*& deviceBuffer) const
    {
        using tensorrt_llm::common::cudaAutoCpy;

        hostBuffer.resize(batch_size);
        if (!optParam)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), defaultValue);
        }
        else if (optParam->size() == 1)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), optParam->front());
        }
        else
        {
            TLLM_CHECK_WITH_INFO(optParam->size() == batch_size, "Argument vector size mismatch.");
            std::copy(optParam->begin(), optParam->end(), std::begin(hostBuffer));
        }
        cudaAutoCpy(deviceBuffer, hostBuffer.data(), batch_size, stream);
    }

    size_t batch_size;
    cudaStream_t stream;
};

} // namespace layers

} // namespace tensorrt_llm
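A short sketch of how a layer uses the new functor in place of the old per-layer lambda (the real call sites appear in the beam search layer hunks below); the buffer and parameter names here are placeholders standing in for whichever host vector, device buffer, and optional per-request value a layer owns.

```cpp
// Sketch only: assumes the layer already owns the host/device buffers and a CUDA stream.
#include <optional>
#include <vector>

#include "tensorrt_llm/layers/fillBuffers.h"

void setupTemperature(std::optional<std::vector<float>> const& temperatureParam, size_t batch_size,
    std::vector<float>& hostTemperature, float*& deviceTemperature, cudaStream_t stream)
{
    tensorrt_llm::layers::FillBuffers const fillBuffers{batch_size, stream};

    // Broadcasts 1.0f when no value is given, broadcasts a single provided value, or copies a
    // size-checked per-request vector, then uploads the host buffer to the device via cudaAutoCpy.
    fillBuffers(temperatureParam, 1.0f, hostTemperature, deviceTemperature);
}
```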
@@ -16,6 +16,7 @@

#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
#include "tensorrt_llm/layers/fillBuffers.h"
#include "tensorrt_llm/layers/onlineBeamSearchLayer.h"

using namespace tensorrt_llm::common;
@@ -49,7 +50,7 @@ __global__ void update_kernel(bool* finished, int** parent_ids_ptr, int* sequenc

    // Increase the seq_len even if the request has finished.
    // On the following iteration we check if the sequence has finished before
    if (!finished[beam_idx])
    if (!finished[blockIdx.x * beam_width + beam_idx])
    {
        s_sequence_lengths[beam_idx]++;
    }
@@ -96,25 +97,7 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
    mDiversityRate = setupParams.beam_search_diversity_rate.value_or(std::vector<float>(0.0f));
    mLengthPenalty = setupParams.length_penalty.value_or(std::vector<float>(0.0f));

    auto fillBuffers
        = [this, &batch_size](auto const& optParam, auto const defaultValue, auto& hostBuffer, auto& deviceBuffer)
    {
        hostBuffer.resize(batch_size);
        if (!optParam)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), defaultValue);
        }
        else if (optParam->size() == 1)
        {
            std::fill(std::begin(hostBuffer), std::end(hostBuffer), optParam->front());
        }
        else
        {
            TLLM_CHECK_WITH_INFO(optParam->size() == batch_size, "Argument vector size mismatch.");
            std::copy(optParam->begin(), optParam->end(), std::begin(hostBuffer));
        }
        cudaAutoCpy(deviceBuffer, hostBuffer.data(), batch_size, stream_);
    };
    FillBuffers const fillBuffers{batch_size, stream_};

    fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_);
    fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_);
@@ -153,7 +136,6 @@ void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, So
        beamHypotheses.end_ids = end_ids;
    }

    output_log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
    invokeTopkSoftMax(logits.template getPtr<T>(), (const T*) (nullptr), finished, sequence_lengths,
        outputs.cum_log_probs->template getPtr<float>(), output_log_probs, output_ids_ptr.getPtr<int*>(),
        topk_softmax_workspace_, topk_softmax_workspace_size_, &beamHypotheses, local_batch_size, beam_width,
@@ -44,7 +44,8 @@ set(PLUGIN_LISTS
    rmsnormQuantizationPlugin
    weightOnlyGroupwiseQuantMatmulPlugin
    weightOnlyQuantMatmulPlugin
    lookupPlugin)
    lookupPlugin
    loraPlugin)

foreach(PLUGIN_ITER ${PLUGIN_LISTS})
  include_directories(${PLUGIN_ITER})

@@ -26,6 +26,7 @@
#include "tensorrt_llm/plugins/layernormPlugin/layernormPlugin.h"
#include "tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h"
#include "tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h"
#include "tensorrt_llm/plugins/loraPlugin/loraPlugin.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h"
@@ -151,6 +152,7 @@ extern "C"
        weightOnlyGroupwiseQuantMatmulPluginCreator;
    static tensorrt_llm::plugins::WeightOnlyQuantMatmulPluginCreator weightOnlyQuantMatmulPluginCreator;
    static tensorrt_llm::plugins::LookupPluginCreator lookupPluginCreator;
    static tensorrt_llm::plugins::LoraPluginCreator loraPluginCreator;

    static std::array pluginCreators
        = { creatorPtr(identityPluginCreator),
@@ -173,6 +175,7 @@ extern "C"
            creatorPtr(weightOnlyGroupwiseQuantMatmulPluginCreator),
            creatorPtr(weightOnlyQuantMatmulPluginCreator),
            creatorPtr(lookupPluginCreator),
            creatorPtr(loraPluginCreator),
    };
    nbCreators = pluginCreators.size();
    return pluginCreators.data();
@ -18,6 +18,7 @@
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
@ -32,7 +33,8 @@ PluginFieldCollection BertAttentionPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> BertAttentionPluginCreator::mPluginAttributes;
|
||||
|
||||
BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_scaling, bool qk_half_accum,
|
||||
ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention, int max_distance)
|
||||
ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention, int max_distance,
|
||||
bool remove_padding)
|
||||
: mNumHeads(num_heads)
|
||||
, mHeadSize(head_size)
|
||||
, mQScaling(q_scaling)
|
||||
@ -42,6 +44,7 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
|
||||
, mType(type)
|
||||
, mRelativeAttention(do_relative_attention)
|
||||
, mMaxDistance(max_distance)
|
||||
, mRemovePadding(remove_padding)
|
||||
{
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM);
|
||||
@ -60,6 +63,7 @@ BertAttentionPlugin::BertAttentionPlugin(const void* data, size_t length)
|
||||
read(d, mType);
|
||||
read(d, mRelativeAttention);
|
||||
read(d, mMaxDistance);
|
||||
read(d, mRemovePadding);
|
||||
TLLM_CHECK(d == a + length);
|
||||
}
|
||||
|
||||
@ -84,13 +88,29 @@ nvinfer1::DimsExprs BertAttentionPlugin::getOutputDimensions(
|
||||
bool BertAttentionPlugin::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
if (pos == 1)
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
// inputs: [0] qkv, [1] input_lengths, [2] max_input_length (optional), [3] relative_attention_bias (optional)
|
||||
// outputs: [X] hidden_states
|
||||
if (nbInputs == 2)
|
||||
{ // BERT
|
||||
if (pos == 1)
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
else if (nbInputs > 2)
|
||||
{ // Encoder in encoder-decoder
|
||||
if (pos == 1 || pos == 2)
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -102,25 +122,17 @@ void BertAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDes
|
||||
size_t BertAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
|
||||
{
|
||||
const int batch_size = inputs[0].dims.d[0];
|
||||
const int input_seq_len = inputs[0].dims.d[1];
|
||||
// if remove padding, inputs[0] "qkv_hidden_states" dim is [1, num_tokens, 3*hidden_dim] which doesn't have shape
|
||||
// info should get max_batch_size and max_input_length from inputs[1] "input_lengths" and input[2]
|
||||
// "max_input_length"
|
||||
const int batch_size = mRemovePadding ? inputs[1].dims.d[0] : inputs[0].dims.d[0];
|
||||
const int input_seq_len = mRemovePadding ? inputs[2].dims.d[0] : inputs[0].dims.d[1];
|
||||
const int local_hidden_units_ = inputs[0].dims.d[2] / 3;
|
||||
const int beam_width = 1;
|
||||
const int max_input_length = input_seq_len;
|
||||
|
||||
size_t size{0U};
|
||||
if (inputs[0].type == DataType::kHALF)
|
||||
{
|
||||
size = sizeof(half);
|
||||
}
|
||||
else if (inputs[0].type == DataType::kFLOAT)
|
||||
{
|
||||
size = sizeof(float);
|
||||
}
|
||||
auto const size = tensorrt_llm::runtime::BufferDataType(inputs[0].type).getSize();
|
||||
|
||||
const size_t attention_mask_size
|
||||
= mEnableContextFMHA ? 0 : size * batch_size * beam_width * max_input_length * max_input_length;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_size * beam_width + 1);
|
||||
const size_t attention_mask_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * input_seq_len;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_size + 1);
|
||||
const size_t q_buf_2_size = size * batch_size * input_seq_len * local_hidden_units_;
|
||||
const size_t k_buf_2_size = size * batch_size * input_seq_len * local_hidden_units_;
|
||||
const size_t v_buf_2_size = size * batch_size * input_seq_len * local_hidden_units_;
|
||||
@ -153,25 +165,27 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
|
||||
{
|
||||
|
||||
// inputs
|
||||
// input_tensor [batch_size, seq_len, local_hidden_size * 3]
|
||||
// input_tensor [batch_size, seq_len, local_hidden_size*3] or [1, num_tokens, local_hidden_size*3]
|
||||
// input_lengths [batch_size]
|
||||
// relative_attention_bias [num_heads, num_buckets] (optional)
|
||||
// max_input_length [max_input_length] -- use shape dim to represent max value. If remove padding, this records
|
||||
// the max input length among sequences; otherwise same as input_tensor's padded dim[1] relative_attention_bias
|
||||
// [num_heads, num_buckets] (optional)
|
||||
// outputs
|
||||
// output_tensor [batch_size, seq_len, local_hidden_size]
|
||||
// output_tensor [batch_size, seq_len, local_hidden_size] or [1, num_tokens, local_hidden_size]
|
||||
|
||||
const int batch_size = inputDesc[0].dims.d[0];
|
||||
// if remove padding, inputs[0] dim is [1, num_tokens] which doesn't have workspace info
|
||||
// should get max_batch_size from inputs[1] and max_input_length from plugin attribute
|
||||
const int batch_size = mRemovePadding ? inputDesc[1].dims.d[0] : inputDesc[0].dims.d[0];
|
||||
const int input_seq_len = mRemovePadding ? inputDesc[2].dims.d[0] : inputDesc[0].dims.d[1];
|
||||
const int num_tokens = mRemovePadding ? inputDesc[0].dims.d[1] : batch_size * input_seq_len;
|
||||
const int request_batch_size = batch_size;
|
||||
const int input_seq_len = inputDesc[0].dims.d[1];
|
||||
const int request_seq_len = input_seq_len;
|
||||
const int beam_width = 1;
|
||||
const int local_hidden_units_ = inputDesc[0].dims.d[2] / 3;
|
||||
const float q_scaling = mQScaling;
|
||||
|
||||
const T* attention_input = reinterpret_cast<const T*>(inputs[0]);
|
||||
|
||||
const int* input_lengths = reinterpret_cast<const int*>(inputs[1]);
|
||||
const T* relative_attn_table = mRelativeAttention ? reinterpret_cast<const T*>(inputs[2]) : nullptr;
|
||||
|
||||
const T* relative_attn_table = mRelativeAttention ? reinterpret_cast<const T*>(inputs[3]) : nullptr;
|
||||
T* context_buf_ = (T*) (outputs[0]);
|
||||
|
||||
auto cublasHandle = mCublasWrapper->getCublasHandle();
|
||||
@ -186,10 +200,15 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
|
||||
{
|
||||
mCublasWrapper->setFP32GemmConfig();
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if constexpr (std::is_same_v<T, __nv_bfloat16>)
|
||||
{
|
||||
mCublasWrapper->setBF16GemmConfig();
|
||||
}
|
||||
#endif
|
||||
|
||||
const size_t attention_mask_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(T) * batch_size * beam_width * input_seq_len * input_seq_len;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_size * beam_width + 1);
|
||||
const size_t attention_mask_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * input_seq_len;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_size + 1);
|
||||
const size_t q_buf_2_size = sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
|
||||
const size_t k_buf_2_size = sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
|
||||
const size_t v_buf_2_size = sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
|
||||
@ -200,8 +219,6 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
|
||||
= mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_len * input_seq_len;
|
||||
const size_t padding_offset_size = sizeof(int) * batch_size * input_seq_len;
|
||||
|
||||
mMaxInputLength = input_seq_len;
|
||||
|
||||
// Workspace pointer shift
|
||||
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
|
||||
size_t offset = CUBLAS_WORKSPACE_SIZE;
|
||||
@@ -223,17 +240,17 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
params.paddingOffsets = padding_offset;
params.attentionMask = attention_mask;
params.seqLengths = input_lengths;
params.batchSize = batch_size * beam_width;
params.maxSeqLength = mMaxInputLength;
params.numTokens = batch_size * beam_width * mMaxInputLength;
params.batchSize = batch_size;
params.maxSeqLength = input_seq_len;
params.numTokens = num_tokens;
params.attentionMaskType = AttentionMaskType::PADDING;
invokeBuildDecoderInfo(params, stream);
sync_check_cuda_error();

// Padding offset = nullptr here (remove padding is not supported).
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(attention_input), input_lengths,
    nullptr, request_batch_size, request_seq_len, batch_size * input_seq_len, mNumHeads, mNumHeads, mHeadSize,
    mEnableContextFMHA, 0, 0.0f, RotaryScalingType::kNONE, 0.0f, 0, PositionEmbeddingType::kLEARNED_ABSOLUTE,
    (float*) nullptr, 0, stream);
    mRemovePadding ? padding_offset : nullptr, batch_size, input_seq_len, num_tokens, mNumHeads, mNumHeads,
    mHeadSize, mEnableContextFMHA, 0, 0.0f, RotaryScalingType::kNONE, 0.0f, 0,
    PositionEmbeddingType::kLEARNED_ABSOLUTE, (float*) nullptr, 0, stream);
const auto gemm_data_type = tc::CudaDataType<T>::value;
const int attention_seq_len_1 = request_seq_len; // q length

@@ -285,7 +302,7 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attn_table, request_batch_size,
    mNumHeads, attention_seq_len_1, attention_seq_len_2, stream, mMaxDistance > 0,
    inputDesc[2].dims.d[1], mMaxDistance, true /* bidirectional */);
    inputDesc[3].dims.d[1], mMaxDistance, true /* bidirectional */);
}

MaskedSoftmaxParam<T, float> param;

@@ -317,7 +334,7 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
// max_output_len + 1. In implicit mode, relative_attention_bias is rel attn table
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_, relative_attn_table, request_batch_size, mNumHeads,
    attention_seq_len_1, attention_seq_len_2, stream, mMaxDistance > 0, inputDesc[2].dims.d[1],
    attention_seq_len_1, attention_seq_len_2, stream, mMaxDistance > 0, inputDesc[3].dims.d[1],
    mMaxDistance, true /* bidirectional */);
}

@@ -339,16 +356,15 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
    attention_seq_len_1 * attention_seq_len_2, qkv_buf_2_, mHeadSize, attention_seq_len_1 * mHeadSize,
    request_batch_size * mNumHeads);

if (padding_offset == nullptr)
if (!mRemovePadding)
{
    invokeTransposeQKV(context_buf_, qkv_buf_2_, request_batch_size, attention_seq_len_1, mNumHeads, mHeadSize,
        (float*) nullptr, 0, stream);
}
else
{
    invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, context_buf_, batch_size * input_seq_len,
        request_batch_size, attention_seq_len_1, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0,
        stream);
    invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, context_buf_, num_tokens, request_batch_size,
        request_seq_len, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0, stream);
}
}
return 0;

@@ -362,6 +378,12 @@ template int BertAttentionPlugin::enqueueImpl<float>(const nvinfer1::PluginTenso
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream);

#ifdef ENABLE_BF16
template int BertAttentionPlugin::enqueueImpl<__nv_bfloat16>(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream);
#endif

int BertAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
    const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept

@@ -374,6 +396,12 @@ int BertAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
{
    return enqueueImpl<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
#ifdef ENABLE_BF16
else if (mType == DataType::kBF16)
{
    return enqueueImpl<__nv_bfloat16>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
#endif
return 0;
}
@@ -425,7 +453,8 @@ void BertAttentionPlugin::destroy() noexcept
size_t BertAttentionPlugin::getSerializationSize() const noexcept
{
    return sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mQScaling) + sizeof(mQKHalfAccum) + sizeof(mEnableContextFMHA)
        + sizeof(mFMHAForceFP32Acc) + sizeof(mType) + sizeof(mRelativeAttention) + sizeof(mMaxDistance);
        + sizeof(mFMHAForceFP32Acc) + sizeof(mType) + sizeof(mRelativeAttention) + sizeof(mMaxDistance)
        + sizeof(mRemovePadding);
}

void BertAttentionPlugin::serialize(void* buffer) const noexcept

@@ -440,6 +469,7 @@ void BertAttentionPlugin::serialize(void* buffer) const noexcept
    write(d, mType);
    write(d, mRelativeAttention);
    write(d, mMaxDistance);
    write(d, mRemovePadding);
    assert(d == a + getSerializationSize());
}
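
Editor's sketch (not part of this diff): the two hunks above append the new mRemovePadding flag to getSerializationSize() and serialize(); the matching deserializing constructor lies outside this excerpt. Assuming the existing read() helper and the same field order that getSerializationSize() lists, its read side presumably ends like this:

// Illustrative only; field order must mirror serialize(), with the new flag last.
BertAttentionPlugin::BertAttentionPlugin(const void* data, size_t length)
{
    const char *d = reinterpret_cast<const char*>(data), *a = d;
    read(d, mNumHeads);
    read(d, mHeadSize);
    read(d, mQScaling);
    read(d, mQKHalfAccum);
    read(d, mEnableContextFMHA);
    read(d, mFMHAForceFP32Acc);
    read(d, mType);
    read(d, mRelativeAttention);
    read(d, mMaxDistance);
    read(d, mRemovePadding); // newly added field
    TLLM_CHECK(d == a + length);
}
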
@@ -459,6 +489,7 @@ BertAttentionPluginCreator::BertAttentionPluginCreator()
    mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("do_relative_attention", nullptr, PluginFieldType::kINT8, 0));
    mPluginAttributes.emplace_back(PluginField("max_distance", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("remove_padding", nullptr, PluginFieldType::kINT8, 0));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

@@ -488,6 +519,7 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(const char* name, const Plug
    nvinfer1::DataType type;
    bool do_relative_attention;
    int max_distance;
    bool remove_padding;
    // Read configurations from each fields
    for (int i = 0; i < fc->nbFields; ++i)
    {

@@ -532,11 +564,16 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(const char* name, const Plug
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            max_distance = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
        }
        else if (!strcmp(attrName, "remove_padding"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
            remove_padding = static_cast<bool>(*(static_cast<const int8_t*>(fields[i].data)));
        }
    }
    try
    {
        auto* obj = new BertAttentionPlugin(num_heads, head_size, q_scaling, qk_half_accum, context_fmha_type, type,
            do_relative_attention, max_distance);
            do_relative_attention, max_distance, remove_padding);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
@@ -36,7 +36,7 @@ public:

    BertAttentionPlugin(int num_heads, int head_size, float q_scaling, bool qk_half_accum,
        tensorrt_llm::kernels::ContextFMHAType context_fmha_type, nvinfer1::DataType type,
        bool do_relative_attention = false, int max_distance = 0);
        bool do_relative_attention = false, int max_distance = 0, bool remove_padding = false);

    BertAttentionPlugin(const void* data, size_t length);

@@ -78,11 +78,11 @@ private:

    int mNumHeads;
    int mHeadSize;
    int mMaxInputLength;
    float mQScaling;
    nvinfer1::DataType mType;
    bool mRelativeAttention = false;
    int mMaxDistance = 0;
    bool mRemovePadding = false;

    // unfused mha
    bool mQKHalfAccum = false;
@@ -217,7 +217,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
    masked_multihead_attention(params, input_params.kv_block_array, stream);
}

#define INSTANTIATE_MMHA_DISPATH(T_MMHA, T) \
#define INSTANTIATE_MMHA_DISPATCH(T_MMHA, T) \
    template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, false>&, \
        const FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer>&, cudaStream_t stream); \
    template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \

@@ -226,12 +226,12 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
        const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream); \
    template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \
        const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream);
INSTANTIATE_MMHA_DISPATH(float, float)
INSTANTIATE_MMHA_DISPATH(uint16_t, half)
INSTANTIATE_MMHA_DISPATCH(float, float)
INSTANTIATE_MMHA_DISPATCH(uint16_t, half)
#ifdef ENABLE_BF16
INSTANTIATE_MMHA_DISPATH(__nv_bfloat16, __nv_bfloat16)
INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
#endif
#undef INSTANTIATE_MMHA_DISPATH
#undef INSTANTIATE_MMHA_DISPATCH

GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
    float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
@ -498,7 +498,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
|
||||
: sizeof(float) * params.batch_size * mNumHeads
|
||||
* params.input_seq_length * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
|
||||
const size_t padding_offset_size = sizeof(int) * params.batch_size * params.input_seq_length;
|
||||
const size_t padding_offset_size
|
||||
= sizeof(int) * params.batch_size * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
|
||||
|
||||
const bool is_qk_buf_float_ = true;
|
||||
|
||||
@ -517,14 +518,17 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
int* padding_offset = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
|
||||
|
||||
// build attention_mask, cu_seqlens, and padding_offset tensors
|
||||
// Note: self attn and cross attn should use different params
|
||||
// cross attn's seqlen info is from encoder input lengths, not decoder input lengths!
|
||||
// moreover, attn mask for cross attn should be set separately (see below)
|
||||
BuildDecoderInfoParams<T> decoder_params;
|
||||
memset(&decoder_params, 0, sizeof(decoder_params));
|
||||
decoder_params.seqOffsets = cu_seqlens;
|
||||
decoder_params.paddingOffsets = padding_offset;
|
||||
decoder_params.attentionMask = attention_mask;
|
||||
decoder_params.seqLengths = params.context_lengths;
|
||||
decoder_params.attentionMask = isCrossAttention() ? nullptr : attention_mask; // manually set for cross attn
|
||||
decoder_params.seqLengths = isCrossAttention() ? params.encoder_input_lengths : params.context_lengths;
|
||||
decoder_params.batchSize = params.batch_size;
|
||||
decoder_params.maxSeqLength = params.input_seq_length;
|
||||
decoder_params.maxSeqLength = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length;
|
||||
decoder_params.maxKvCacheLength = params.cyclic_kv_cache_length;
|
||||
decoder_params.numTokens = params.num_tokens;
|
||||
decoder_params.attentionMaskType = mMaskType;
|
||||
@ -533,9 +537,27 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
|
||||
// In cross attention context phase, the attention mask should be a matrix of all ones.
|
||||
// We reassign attention_mask to override what previous invokeBuildDecoderInfo() does
|
||||
// also, invokeBuildDecoderInfo can only handle square mask, not cross B x q_len x kv_len mask
|
||||
// TODO: put this logic in the kernel above. currently not much concern because q_len is mostly = 1
|
||||
if (isCrossAttention())
|
||||
{
|
||||
std::vector<T> h_attention_mask(params.batch_size * params.cross_qkv_length * params.input_seq_length, 1.);
|
||||
std::vector<T> h_attention_mask(params.batch_size * params.input_seq_length * params.cross_qkv_length, 1.);
|
||||
std::vector<int32_t> h_encoder_input_lengths(params.batch_size);
|
||||
cudaMemcpyAsync(h_encoder_input_lengths.data(), params.encoder_input_lengths,
|
||||
sizeof(int32_t) * params.batch_size, cudaMemcpyDeviceToHost, stream);
|
||||
for (int bi = 0; bi < params.batch_size; bi++)
|
||||
{
|
||||
int b_offset = bi * params.input_seq_length * params.cross_qkv_length;
|
||||
for (int qi = 0; qi < params.input_seq_length; qi++)
|
||||
{
|
||||
int q_offset = b_offset + qi * params.cross_qkv_length;
|
||||
if (h_encoder_input_lengths[bi] < params.cross_qkv_length)
|
||||
{
|
||||
std::fill(h_attention_mask.begin() + q_offset + h_encoder_input_lengths[bi],
|
||||
h_attention_mask.begin() + q_offset + params.cross_qkv_length, 0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
cudaMemcpyAsync(attention_mask, h_attention_mask.data(),
|
||||
sizeof(T) * params.batch_size * params.cross_qkv_length * params.input_seq_length, cudaMemcpyHostToDevice,
|
||||
stream);
|
||||
@ -580,6 +602,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
// FIXME: a temporary solution to make sure the padding part of key/value buffer is 0
|
||||
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
|
||||
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
|
||||
// cudaMemsetAsync(k_buf_2_, 0, reinterpret_cast<int8_t*>(qk_buf_) - reinterpret_cast<int8_t*>(k_buf_2_),
|
||||
// stream);
|
||||
cudaMemsetAsync(k_buf_2_, 0,
|
||||
reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
|
||||
|
||||
@ -698,23 +722,22 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
}
|
||||
}
|
||||
|
||||
// add relative position bias
|
||||
if (isRelativePosition())
|
||||
{
|
||||
// Add relative_attention_bias
|
||||
// QK is (batch_size, local_head_num, q_length, k_length), relative_attention_bias is (1, local_head_num,
|
||||
// max_output_len + 1, max_output_len + 1).
|
||||
// broadcast along 1st dim. max_seq_len is already max_output_len + 1.
|
||||
// In implicit mode, relative_attention_bias is relative_attention_table [num_heads, num_buckets], with
|
||||
// necessary params (max_distance, num_buckets) passed at the end
|
||||
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
|
||||
mNumHeads, attention_seq_len_1,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream, max_distance > 0,
|
||||
relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
||||
}
|
||||
|
||||
if (is_qk_buf_float_ == true)
|
||||
{
|
||||
// add relative position bias
|
||||
if (isRelativePosition())
|
||||
{
|
||||
// Add relative_attention_bias
|
||||
// QK is (batch_size, local_head_num, q_length, k_length), relative_attention_bias is (1,
|
||||
// local_head_num, max_output_len + 1, max_output_len + 1). broadcast along 1st dim. max_seq_len is
|
||||
// already max_output_len + 1. In implicit mode, relative_attention_bias is relative_attention_table
|
||||
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
|
||||
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
|
||||
mNumHeads, attention_seq_len_1,
|
||||
isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length, stream,
|
||||
max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
||||
}
|
||||
|
||||
MaskedSoftmaxParam<T, float> param;
|
||||
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
||||
param.qk = qk_buf_float_; // (batch_size, head_num, q_length, k_length)
|
||||
@ -729,6 +752,19 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
}
|
||||
else
|
||||
{
|
||||
// add relative position bias
|
||||
if (isRelativePosition())
|
||||
{
|
||||
// Add relative_attention_bias
|
||||
// QK is (batch_size, local_head_num, q_length, k_length), relative_attention_bias is (1,
|
||||
// local_head_num, max_output_len + 1, max_output_len + 1). broadcast along 1st dim. max_seq_len is
|
||||
// already max_output_len + 1. In implicit mode, relative_attention_bias is relative_attention_table
|
||||
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
|
||||
invokeAddRelativeAttentionBiasUnaligned(qk_buf_, relative_attention_bias, params.batch_size, mNumHeads,
|
||||
attention_seq_len_1, isCrossAttention() ? params.cross_qkv_length : params.cyclic_kv_cache_length,
|
||||
stream, max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
||||
}
|
||||
|
||||
MaskedSoftmaxParam<T, T> param;
|
||||
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
||||
param.qk = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
||||
@ -935,7 +971,6 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
dispatch_params.input_lengths = params.context_lengths;
|
||||
dispatch_params.step = step;
|
||||
dispatch_params.q_scaling = q_scaling;
|
||||
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
|
||||
dispatch_params.linear_bias_slopes = isALiBi() ? params.alibi_slopes : nullptr;
|
||||
dispatch_params.ia3_tasks = ia3_tasks;
|
||||
dispatch_params.ia3_key_weights = ia3_key_weights;
|
||||
|
||||
@ -250,7 +250,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
// such model has an encoder context (for cross attn) and an decoder context (for self attn)
|
||||
// clarify 3 lens:
|
||||
// -- max_context_len: len of decoder input. No "max" concept, it's what it is given.
|
||||
// Also called (decoder_)input_seq_length
|
||||
// Also called (decoder_)input_seq_length, normally 1 for encoder-decoder start token
|
||||
// -- max_seq_len: max allowed len of decoder output, i.e. final results
|
||||
// -- max_encoder_context_len: len of encoder input (in cross attn). Also called encoder_input_seq_length
|
||||
|
||||
|
||||
cpp/tensorrt_llm/plugins/loraPlugin/CMakeLists.txt (new file, 21 lines)
@@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
    ${PLUGIN_SOURCES}
    PARENT_SCOPE)
cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp (new file, 573 lines)
@@ -0,0 +1,573 @@
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "loraPlugin.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using namespace tensorrt_llm::common;
|
||||
using tensorrt_llm::plugins::LoraPluginCreator;
|
||||
using tensorrt_llm::plugins::LoraPlugin;
|
||||
using tensorrt_llm::plugins::CublasGemmWrapperPtr;
|
||||
using tensorrt_llm::plugins::read;
|
||||
using tensorrt_llm::plugins::write;
|
||||
|
||||
static const char* LORA_PLUGIN_VERSION{"1"};
|
||||
static const char* LORA_PLUGIN_NAME{"Lora"};
|
||||
PluginFieldCollection LoraPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> LoraPluginCreator::mPluginAttributes;
|
||||
|
||||
// TODO should reuse the function in gemmPlugin
|
||||
void _getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb,
|
||||
int& ldc, bool transA, bool transB, int M, int N, int K)
|
||||
{
|
||||
transa = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
transb = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
m = N;
|
||||
n = M;
|
||||
k = K;
|
||||
lda = transB ? K : N;
|
||||
ldb = transA ? M : K;
|
||||
ldc = N;
|
||||
}
|
||||
|
||||
// TODO should reuse the function in gemmPlugin
|
||||
void _runGemm(const int M, const int N, const int K, const bool transA, const bool transB,
|
||||
const nvinfer1::DataType type, const CublasGemmWrapperPtr& cublasWrapperPtr, const void* act, const void* weight,
|
||||
void* output, const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic, void* workspace, cudaStream_t stream)
|
||||
{
|
||||
cublasWrapperPtr->setStream(stream);
|
||||
cublasWrapperPtr->setWorkspace(workspace);
|
||||
|
||||
cublasOperation_t transa, transb;
|
||||
int m, n, k;
|
||||
int lda, ldb, ldc;
|
||||
_getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
|
||||
|
||||
cublasWrapperPtr->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
|
||||
cublasWrapperPtr->Gemm(transa, transb, m, n, k, weight, lda, act, ldb, output, ldc, heuristic);
|
||||
cublasWrapperPtr->destroyDescriptors();
|
||||
}
|
||||
|
||||
LoraPlugin::LoraPlugin(int in_hidden_size, int out_hidden_size, int transA, int transB, int lora_module_number,
|
||||
nvinfer1::DataType type, const LoraPlugin::PluginProfilerPtr& pluginProfiler, bool remove_input_padding,
|
||||
int max_context_length, int max_low_rank)
|
||||
: mInHiddenSize(in_hidden_size)
|
||||
, mOutHiddenSize(out_hidden_size)
|
||||
, mTransA(transA)
|
||||
, mTransB(transB)
|
||||
, mType(type)
|
||||
, mPluginProfiler(pluginProfiler)
|
||||
, mRemoveInputPadding(remove_input_padding)
|
||||
, mMaxContextLength(max_context_length)
|
||||
, mMaxLowRank(max_low_rank)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
init();
|
||||
}
|
||||
|
||||
// Parameterized constructor
|
||||
LoraPlugin::LoraPlugin(const void* data, size_t length, const LoraPlugin::PluginProfilerPtr& pluginProfiler)
|
||||
: mPluginProfiler(pluginProfiler)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
const char *d = reinterpret_cast<const char*>(data), *a = d;
|
||||
read(d, mInHiddenSize);
|
||||
read(d, mOutHiddenSize);
|
||||
read(d, mTransA);
|
||||
read(d, mTransB);
|
||||
read(d, mType);
|
||||
read(d, mRemoveInputPadding);
|
||||
read(d, mMaxContextLength);
|
||||
read(d, mMaxLowRank);
|
||||
|
||||
init();
|
||||
|
||||
mPluginProfiler->deserialize(d, mDims, mGemmId);
|
||||
|
||||
TLLM_CHECK(d == a + length);
|
||||
}
|
||||
|
||||
void LoraPlugin::init()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
auto cublasHandle = getCublasHandle();
|
||||
auto cublasLtHandle = getCublasLtHandle();
|
||||
mCublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
|
||||
|
||||
mPluginProfiler->setTranspose(mTransA, mTransB);
|
||||
|
||||
mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB);
|
||||
}
|
||||
|
||||
void LoraPlugin::setGemmConfig()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
if (mType == DataType::kHALF)
|
||||
{
|
||||
mCublasWrapper->setFP16GemmConfig();
|
||||
}
|
||||
else if (mType == DataType::kFLOAT)
|
||||
{
|
||||
mCublasWrapper->setFP32GemmConfig();
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (mType == DataType::kBF16)
|
||||
{
|
||||
mCublasWrapper->setBF16GemmConfig();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void LoraPlugin::configGemm()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
if (!mDims.isInitialized())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
setGemmConfig();
|
||||
|
||||
mPluginProfiler->profileTactics(mCublasWrapper, mType, mDims, mGemmId);
|
||||
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* LoraPlugin::clone() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
auto* plugin = new LoraPlugin(*this);
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs LoraPlugin::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
try
|
||||
{
|
||||
TLLM_CHECK(outputIndex == 0);
|
||||
const int nbDimsA = inputs[getInputTensorIdx()].nbDims;
|
||||
DimsExprs ret;
|
||||
ret.nbDims = nbDimsA;
|
||||
|
||||
for (int i = 0; i < ret.nbDims; ++i)
|
||||
{
|
||||
ret.d[0] = 0;
|
||||
}
|
||||
|
||||
if (mTransA)
|
||||
{
|
||||
for (int i = 1; i < nbDimsA; ++i)
|
||||
{
|
||||
ret.d[i - 1] = inputs[getInputTensorIdx()].d[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < nbDimsA - 1; ++i)
|
||||
{
|
||||
ret.d[i] = inputs[getInputTensorIdx()].d[i];
|
||||
}
|
||||
}
|
||||
|
||||
auto const* outHiddenSize = exprBuilder.constant(mOutHiddenSize);
|
||||
TLLM_CHECK(outHiddenSize != nullptr);
|
||||
ret.d[ret.nbDims - 1] = outHiddenSize;
|
||||
return ret;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return DimsExprs{};
|
||||
}
|
||||
|
||||
bool LoraPlugin::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
if (pos == getHostRequestTypesIdx())
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
else if (pos == getLoraRanksIdx())
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
else if (pos == getLoraWeightsPtrsIdx())
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT64;
|
||||
}
|
||||
else if (mRemoveInputPadding && pos == getHostContextLengthsIdx())
|
||||
{
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t _computeMDimension(bool transA, const int32_t nbDims, const int32_t* dims)
|
||||
{
|
||||
int32_t M = 1;
|
||||
if (transA)
|
||||
{
|
||||
for (int i = nbDims - 1; i > 0; --i)
|
||||
{
|
||||
M *= dims[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < nbDims - 1; ++i)
|
||||
{
|
||||
M *= dims[i];
|
||||
}
|
||||
}
|
||||
return M;
|
||||
}
|
||||
|
||||
int32_t _computeNDimension(bool transB, const int32_t nbDims, const int32_t* dims)
{
    int32_t N = 1;
    if (transB)
    {
        for (int i = 0; i < nbDims - 1; ++i)
        {
            N *= dims[i];
        }
    }
    else
    {
        for (int i = nbDims - 1; i > 0; --i)
        {
            N *= dims[i];
        }
    }
    return N;
}
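
A quick worked example of the two helpers above, using hypothetical shapes that are not taken from this diff: _computeMDimension folds every activation dimension except the reduction axis into the GEMM M, and _computeNDimension does the same for the weight's N.

// Hypothetical shapes for illustration only.
const int32_t act_dims[] = {4, 128, 1024};  // [batch, seq_len, K]
const int32_t w_dims[] = {1024, 64};        // [K, low_rank], assumed layout
const int32_t M = _computeMDimension(/*transA=*/false, 3, act_dims); // 4 * 128 = 512
const int32_t N = _computeNDimension(/*transB=*/false, 2, w_dims);   // 64
// K is then taken from the last activation dim (1024) when transA == false.
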
void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
const int nbDimsA = in[0].max.nbDims;
|
||||
const int nbDimsB = in[1].max.nbDims;
|
||||
|
||||
const auto minM = _computeMDimension(mTransA, nbDimsA, in[0].min.d);
|
||||
const auto maxM = _computeMDimension(mTransA, nbDimsA, in[0].max.d);
|
||||
const auto N = _computeNDimension(mTransB, nbDimsB, in[1].max.d);
|
||||
const auto K = mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1];
|
||||
|
||||
if (!mDims.isInitialized())
|
||||
{
|
||||
mDims = {minM, maxM, N, K};
|
||||
}
|
||||
mGemmId.n = N;
|
||||
mGemmId.k = K;
|
||||
}
|
||||
|
||||
size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
const int nbReq = inputs[getLoraRanksIdx()].dims.d[0];
|
||||
auto const type = inputs[getInputTensorIdx()].type;
|
||||
auto const typeSize = tensorrt_llm::runtime::BufferDataType(type).getSize();
|
||||
|
||||
size_t const lowRankWorkSpaceSize = nbReq * mMaxContextLength * mMaxLowRank * typeSize;
|
||||
|
||||
return CUBLAS_WORKSPACE_SIZE + lowRankWorkSpaceSize;
|
||||
}
|
||||
|
||||
int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
// inputs
|
||||
// input [-1, K] (view as 2D)
|
||||
// host_request_type [batch_size] on cpu
|
||||
// lora_ranks [batch_size] on cpu
|
||||
// lora_weights_ptr [batch_size, 2] on cpu
|
||||
// host_context_lengths [batch_size] on cpu
|
||||
// outputs
|
||||
// output [-1, N] (view as 2D)
|
||||
|
||||
auto const typeSize = tensorrt_llm::runtime::BufferDataType(mType).getSize();
|
||||
void* cublasWorkSpace = workspace;
|
||||
void* lowRankWorkSpace = static_cast<char*>(cublasWorkSpace) + CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
setGemmConfig();
|
||||
auto const batch_size = inputDesc[getLoraRanksIdx()].dims.d[0];
|
||||
auto const lora_ranks = static_cast<int32_t const*>(inputs[getLoraRanksIdx()]);
|
||||
auto const lora_weights_ptr = static_cast<int64_t const*>(inputs[getLoraWeightsPtrsIdx()]);
|
||||
auto const host_context_lengths
|
||||
= mRemoveInputPadding ? static_cast<int32_t const*>(inputs[getHostContextLengthsIdx()]) : nullptr;
|
||||
RequestType const* reqTypes = static_cast<RequestType const*>(inputs[getHostRequestTypesIdx()]);
|
||||
|
||||
size_t handled_token_num = 0;
|
||||
for (int batchIdx = 0; batchIdx < batch_size; batchIdx++)
|
||||
{
|
||||
const RequestType reqType = reqTypes[batchIdx];
|
||||
const auto M = (reqType != RequestType::kCONTEXT)
|
||||
? 1
|
||||
: (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]);
|
||||
const auto lora_rank = lora_ranks[batchIdx];
|
||||
|
||||
if (lora_rank <= 0)
|
||||
{
|
||||
const auto N = outputDesc[0].dims.d[outputDesc[0].dims.nbDims - 1];
|
||||
void* output = static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N * typeSize);
|
||||
if (typeSize == 2)
|
||||
{
|
||||
deviceFill((half*) output, M * N, (half) 0.0f, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
deviceFill((float*) output, M * N, 0.0f, stream);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// the input shape should be [1, token_num, K] under remove_input_padding,
|
||||
// [batch, seqlen, K] under non-remove_input_padding
|
||||
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
|
||||
|
||||
const int nbDimsA = inputDesc[0].dims.nbDims;
|
||||
const auto N = lora_rank;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(N <= mMaxLowRank,
|
||||
fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank));
|
||||
const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size
|
||||
const auto N2 = outputDesc[0].dims.d[nbDimsA - 1];
|
||||
            // [M, K] -> [M, N] -> [M, N2]

            void* lora_in_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 0]);
            void* lora_out_weight = reinterpret_cast<void*>(lora_weights_ptr[batchIdx * 2 + 1]);
            const void* input
                = static_cast<const void*>(static_cast<const char*>(inputs[0]) + handled_token_num * K * typeSize);
            void* output = static_cast<void*>(static_cast<char*>(outputs[0]) + handled_token_num * N2 * typeSize);
            _runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, input, lora_in_weight, lowRankWorkSpace,
                bestTactic, cublasWorkSpace, stream);

            _runGemm(M, N2, N, mTransA, mTransB, mType, mCublasWrapper, lowRankWorkSpace, lora_out_weight, output,
                bestTactic, cublasWorkSpace, stream);
        }
        handled_token_num += M;
    }
    return 0;
}
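
For orientation, the per-request body of enqueue() boils down to two chained GEMMs: the activation rows [M, K] are projected into the low-rank space [M, N] with the "in" weight, then expanded to [M, N2] with the "out" weight. The plain-loop sketch below pins down that data flow; it is illustrative only (row-major float, no transposes assumed), not the plugin's actual cuBLAS path.

// Reference sketch of what the two _runGemm calls compute for one request.
void loraReference(const float* x,     // [M, K] activation rows
    const float* w_in,                 // [K, r] low-rank "in" weight (assumed layout)
    const float* w_out,                // [r, N2] low-rank "out" weight (assumed layout)
    float* tmp,                        // [M, r] scratch, playing the role of lowRankWorkSpace
    float* y,                          // [M, N2] output rows
    int M, int K, int r, int N2)
{
    for (int i = 0; i < M; ++i)
    {
        for (int j = 0; j < r; ++j)
        {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
            {
                acc += x[i * K + k] * w_in[k * r + j];
            }
            tmp[i * r + j] = acc;
        }
        for (int j = 0; j < N2; ++j)
        {
            float acc = 0.f;
            for (int k = 0; k < r; ++k)
            {
                acc += tmp[i * r + k] * w_out[k * N2 + j];
            }
            y[i * N2 + j] = acc;
        }
    }
}
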
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType LoraPlugin::getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(index == 0);
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
const char* LoraPlugin::getPluginType() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return LORA_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* LoraPlugin::getPluginVersion() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return LORA_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int LoraPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int LoraPlugin::initialize() noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
configGemm();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void LoraPlugin::destroy() noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
delete this;
|
||||
}
|
||||
|
||||
size_t LoraPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return sizeof(mInHiddenSize) + sizeof(mOutHiddenSize) + sizeof(mTransA) + sizeof(mTransB) + sizeof(mType)
|
||||
+ mPluginProfiler->getSerializationSize(mGemmId) + sizeof(mRemoveInputPadding) + sizeof(mMaxContextLength)
|
||||
+ sizeof(mMaxLowRank); // selected tactics container size
|
||||
}
|
||||
|
||||
void LoraPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mInHiddenSize);
|
||||
write(d, mOutHiddenSize);
|
||||
write(d, mTransA);
|
||||
write(d, mTransB);
|
||||
write(d, mType);
|
||||
write(d, mRemoveInputPadding);
|
||||
write(d, mMaxContextLength);
|
||||
write(d, mMaxLowRank);
|
||||
mPluginProfiler->serialize(d, mGemmId);
|
||||
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
void LoraPlugin::terminate() noexcept {}
|
||||
|
||||
///////////////
|
||||
|
||||
LoraPluginCreator::LoraPluginCreator()
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("transA", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("transB", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("lora_module_number", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char* LoraPluginCreator::getPluginName() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return LORA_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char* LoraPluginCreator::getPluginVersion() const noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return LORA_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
const PluginFieldCollection* LoraPluginCreator::getFieldNames() noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2* LoraPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
|
||||
const PluginField* fields = fc->fields;
|
||||
nvinfer1::DataType type;
|
||||
int lora_module_number;
|
||||
int in_hidden_size, out_hidden_size, transA, transB;
|
||||
bool remove_input_padding;
|
||||
int max_context_length;
|
||||
int max_low_rank;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
const char* attrName = fields[i].name;
|
||||
if (!strcmp(attrName, "in_hidden_size"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
in_hidden_size = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "out_hidden_size"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
out_hidden_size = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "transa"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
transA = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "transb"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
transB = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "type_id"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
type = static_cast<nvinfer1::DataType>(*(static_cast<const nvinfer1::DataType*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "remove_input_padding"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
|
||||
remove_input_padding = static_cast<bool>(*(static_cast<const int8_t*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "max_context_length"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
max_context_length = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "max_low_rank"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
max_low_rank = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
|
||||
}
|
||||
}
|
||||
try
|
||||
{
|
||||
// LoraPluginCreator is unique and shared for an engine generation
|
||||
// Create plugin profiler with shared tactics map
|
||||
// FIXME enable tactic profiler
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
|
||||
auto* obj = new LoraPlugin(in_hidden_size, out_hidden_size, transA, transB, lora_module_number, type,
|
||||
pluginProfiler, remove_input_padding, max_context_length, max_low_rank);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IPluginV2* LoraPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call LoraPlugin::destroy()
|
||||
try
|
||||
{
|
||||
// LoraPluginCreator is unique and shared for an engine generation
|
||||
// Create plugin profiler with shared tactics map
|
||||
// FIXME enable tactic profiler
|
||||
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ true, /* skip */ true);
|
||||
auto* obj = new LoraPlugin(serialData, serialLength, pluginProfiler);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h (new file, 161 lines)
@@ -0,0 +1,161 @@
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef TRT_LORA_PLUGIN_H
|
||||
#define TRT_LORA_PLUGIN_H
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
using CublasGemmWrapper = tensorrt_llm::common::CublasMMWrapper;
|
||||
using CublasGemmWrapperPtr = std::shared_ptr<CublasGemmWrapper>;
|
||||
|
||||
class LoraPlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
using PluginProfilerPtr = std::shared_ptr<CublasLtGemmPluginProfiler>;
|
||||
|
||||
LoraPlugin() = delete;
|
||||
|
||||
LoraPlugin(int in_hidden_size, int out_hidden_size, int transA, int transB, int lora_module_number,
|
||||
nvinfer1::DataType type, const PluginProfilerPtr& profiler, bool remove_input_padding, int max_context_length,
|
||||
int max_low_rank);
|
||||
|
||||
LoraPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler);
|
||||
|
||||
~LoraPlugin() override = default;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
||||
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char* getPluginType() const noexcept override;
|
||||
const char* getPluginVersion() const noexcept override;
|
||||
int getNbOutputs() const noexcept override;
|
||||
int initialize() noexcept override;
|
||||
void terminate() noexcept override;
|
||||
size_t getSerializationSize() const noexcept override;
|
||||
void serialize(void* buffer) const noexcept override;
|
||||
void destroy() noexcept override;
|
||||
|
||||
private:
|
||||
void init();
|
||||
void configGemm();
|
||||
void setGemmConfig();
|
||||
|
||||
using IndexType = std::int32_t;
|
||||
|
||||
IndexType getInputTensorIdx() const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
IndexType getHostRequestTypesIdx() const
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
IndexType getLoraRanksIdx() const
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
IndexType getLoraWeightsPtrsIdx() const
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
|
||||
IndexType getHostContextLengthsIdx() const
|
||||
{
|
||||
TLLM_CHECK(mRemoveInputPadding);
|
||||
return 4;
|
||||
}
|
||||
|
||||
enum class RequestType : int32_t
|
||||
{
|
||||
kCONTEXT = 0,
|
||||
kGENERATION = 1
|
||||
};
|
||||
|
||||
private:
|
||||
const std::string mLayerName;
|
||||
|
||||
int mInHiddenSize;
|
||||
int mOutHiddenSize;
|
||||
int mTransA;
|
||||
int mTransB;
|
||||
nvinfer1::DataType mType;
|
||||
bool mRemoveInputPadding;
|
||||
int mMaxContextLength;
|
||||
int mMaxLowRank;
|
||||
|
||||
// @fixme: seems this is shared across multiple clones.
|
||||
// If we deep copy the wrapper inside clone(), then we may avoid the mutex inside the wrapper?
|
||||
CublasGemmWrapperPtr mCublasWrapper;
|
||||
|
||||
GemmDims mDims{};
|
||||
GemmIdCublas mGemmId{};
|
||||
|
||||
PluginProfilerPtr mPluginProfiler;
|
||||
};
|
||||
|
||||
class LoraPluginCreator : public BaseCreator
|
||||
{
|
||||
public:
|
||||
LoraPluginCreator();
|
||||
|
||||
const char* getPluginName() const noexcept override;
|
||||
|
||||
const char* getPluginVersion() const noexcept override;
|
||||
|
||||
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* deserializePlugin(
|
||||
const char* name, const void* serialData, size_t serialLength) noexcept override;
|
||||
|
||||
private:
|
||||
GemmPluginProfilerManager<CublasLtGemmPluginProfiler> gemmPluginProfileManager;
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::plugins
|
||||
|
||||
#endif // TRT_LORA_PLUGIN_H
|
||||
@@ -333,12 +333,13 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
    const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
    const half* act_ptr = reinterpret_cast<const half*>((mQuantAlgo & PRE_QUANT_SCALE) ? workspace : inputs[0]);

    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF
#if defined(ENABLE_BF16)
            || mType == nvinfer1::DataType::kBF16
#endif
        ,
    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,
        "No valid weightOnlyGropwiseQuantMatmul configuration");
#else
    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration");
#endif

    tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type;
    int real_n = n * INT8_INT4_RATIO;
    if (mType == nvinfer1::DataType::kHALF)
@@ -288,12 +288,12 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
    const int ws_size = m_weightOnlyGemmRunner->getWorkspaceSize(m, n, k);
    const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId);
    TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic");
    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF ||
#if defined(ENABLE_BF16)
            mType == nvinfer1::DataType::kBF16
#endif
        ,
    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,
        "No valid weightOnlyQuantMatmul configuration");
#else
    TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyQuantMatmul configuration");
#endif

    tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type;
    tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type;
@ -15,6 +15,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
@ -72,7 +73,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("ids", &tpr::GenerationOutput::ids)
|
||||
.def_readwrite("lengths", &tpr::GenerationOutput::lengths)
|
||||
.def_readwrite("log_probs", &tpr::GenerationOutput::logProbs)
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits);
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits)
|
||||
.def_readwrite("on_token_generated", &tpr::GenerationOutput::onTokenGenerated);
|
||||
|
||||
py::class_<tb::kv_cache_manager::KvCacheConfig>(m, "KvCacheConfig")
|
||||
.def(py::init<std::optional<tr::SizeType>, std::optional<tr::SizeType>, std::optional<float>>(),
|
||||
@ -175,6 +177,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_property("compute_context_logits",
|
||||
py::overload_cast<>(&tr::GptModelConfig::computeContextLogits, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::computeContextLogits))
|
||||
.def_property("compute_generation_logits",
|
||||
py::overload_cast<>(&tr::GptModelConfig::computeGenerationLogits, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::computeGenerationLogits))
|
||||
.def_property("model_variant", &tr::GptModelConfig::getModelVariant, &tr::GptModelConfig::setModelVariant)
|
||||
.def_property("use_custom_all_reduce", py::overload_cast<>(&tr::GptModelConfig::useCustomAllReduce, py::const_),
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::useCustomAllReduce));
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
*/
|
||||
#include "generationOutput.h"
|
||||
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
@ -34,6 +35,12 @@ std::shared_ptr<tr::GenerationOutput> GenerationOutput::toTrtLlm() const
|
||||
{
|
||||
output->contextLogits = tr::TorchView::of(contextLogits.value());
|
||||
}
|
||||
// TODO(mseznec): add support for onTokenGenerated
|
||||
|
||||
if (onTokenGenerated)
|
||||
{
|
||||
output->onTokenGenerated = [delegate = onTokenGenerated](
|
||||
tr::GenerationOutput::TensorPtr const& ids, tr::SizeType step, bool finished)
|
||||
{ delegate(tr::Torch::tensor(ids), step, finished); };
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
@@ -114,7 +114,8 @@ void BufferManager::copy(IBuffer const& src, void* dst, MemoryType dstType) cons

void BufferManager::copy(IBuffer const& src, IBuffer& dst) const
{
    TLLM_CHECK_WITH_INFO(src.getDataType() == dst.getDataType(), "Incompatible data types");
    TLLM_CHECK_WITH_INFO(src.getDataType() == dst.getDataType(),
        tc::fmtstr("Incompatible data types: %s != %s", src.getDataTypeName(), dst.getDataTypeName()));
    TLLM_CHECK_WITH_INFO(src.getSizeInBytes() == dst.getSizeInBytes(),
        tc::fmtstr("Incompatible buffer sizes: %lu != %lu", src.getSizeInBytes(), dst.getSizeInBytes()));
    copy(src, dst.data(), dst.getMemoryType());

@@ -192,3 +193,49 @@ void BufferManager::initMemoryPool(int device)
    auto maxThreshold = std::numeric_limits<std::uint64_t>::max();
    TLLM_CUDA_CHECK(cudaMemPoolSetAttribute(memPool, cudaMemPoolAttrReleaseThreshold, &maxThreshold));
}

std::size_t BufferManager::memoryPoolReserved(int device)
{
    ::cudaMemPool_t memPool;
    TLLM_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&memPool, device));
    std::size_t reserved = 0;
    TLLM_CUDA_CHECK(cudaMemPoolGetAttribute(memPool, cudaMemPoolAttrReservedMemCurrent, &reserved));
    return reserved;
}

std::size_t BufferManager::memoryPoolUsed(int device)
{
    ::cudaMemPool_t memPool;
    TLLM_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&memPool, device));
    std::size_t used = 0;
    TLLM_CUDA_CHECK(cudaMemPoolGetAttribute(memPool, cudaMemPoolAttrUsedMemCurrent, &used));
    return used;
}

void BufferManager::memoryPoolTrimTo(int device, std::size_t size)
{
    ::cudaMemPool_t memPool;
    TLLM_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&memPool, device));
    TLLM_CUDA_CHECK(cudaMemPoolTrimTo(memPool, size));
}

std::size_t BufferManager::memoryPoolReserved() const
{
    return memoryPoolReserved(mStream->getDevice());
}

std::size_t BufferManager::memoryPoolUsed() const
{
    return memoryPoolUsed(mStream->getDevice());
}

std::size_t BufferManager::memoryPoolFree() const
{
    return memoryPoolFree(mStream->getDevice());
}

void BufferManager::memoryPoolTrimTo(std::size_t size)
{
    mStream->synchronize();
    memoryPoolTrimTo(mStream->getDevice(), size);
}
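
The helpers added above expose the CUDA default memory pool of the manager's device. A caller-side usage sketch follows; the include path and the reportAndTrim name are assumptions for illustration, not part of this diff.

#include "tensorrt_llm/runtime/bufferManager.h"  // assumed header location

#include <cstdio>

using tensorrt_llm::runtime::BufferManager;

// Log pool occupancy, then hand unused reserved memory back to the driver.
void reportAndTrim(BufferManager& manager)
{
    const std::size_t reserved = manager.memoryPoolReserved(); // bytes held by the pool
    const std::size_t used = manager.memoryPoolUsed();         // bytes currently in use
    const std::size_t free = manager.memoryPoolFree();         // free bytes remaining in the pool
    std::printf("pool reserved=%zu used=%zu free=%zu\n", reserved, used, free);

    // Synchronizes the manager's stream first, then trims the pool down to 0 bytes.
    manager.memoryPoolTrimTo(0);
}
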
@ -41,10 +41,13 @@ GptDecoder<T>::GptDecoder(size_t vocabSize, size_t vocabSizePadded, CudaStreamPt
|
||||
|
||||
mDynamicDecodeLayer = std::make_shared<tensorrt_llm::layers::DynamicDecodeLayer<T>>(
vocabSize, vocabSizePadded, stream->get(), &mAllocator, isFreeBufferAfterForward, &prop);

auto constexpr nvFloatType = TRTDataType<float>::value;
mLogProbsTiled = mManager.emptyTensor(MemoryType::kGPU, nvFloatType);
}

template <typename T>
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize)
void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength)
{
typename layers::DynamicDecodeLayer<T>::SetupParams setupParams;

@@ -71,6 +74,10 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
setupParams.length_penalty = samplingConfig.lengthPenalty;

mDynamicDecodeLayer->setup(batchSize, samplingConfig.beamWidth, setupParams);

mLogProbsTiled->reshape(
ITensor::makeShape({maxSequenceLength, static_cast<SizeType>(batchSize), samplingConfig.beamWidth}));
mManager.setZero(*mLogProbsTiled);
}

namespace
@@ -128,7 +135,7 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co

template <typename T>
typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths)
DecodingOutput& output, DecodingInput::TensorPtr const& inputLengths, DecodingOutput::TensorPtr& logProbsTiled)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
typename tl::DynamicDecodeLayer<T>::OutputParams outputParams(tcc::toTllmTensor(*output.ids));
@@ -168,6 +175,7 @@ typename tl::DynamicDecodeLayer<T>::OutputParams prepareOutputs(
if (output.logProbs)
{
outputParams.output_log_probs = tcc::toTllmTensor(*output.logProbs);
outputParams.output_log_probs_tiled = tcc::toTllmTensor(*logProbsTiled);
}

outputParams.beamHypotheses = std::make_shared<tensorrt_llm::kernels::BeamHypotheses>();
@@ -218,7 +226,7 @@ bool GptDecoder<T>::forward(DecodingOutput& output, DecodingInput const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);

BufferManager::ITensorPtr finishedSum;
std::int32_t* finishedSumHost = nullptr;
@@ -256,19 +264,14 @@ void GptDecoder<T>::forwardAsync(DecodingOutput& output, DecodingInput const& in
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto forwardParams = prepareInputs<T>(input);
auto outputParams = prepareOutputs<T>(output, input.lengths);
auto outputParams = prepareOutputs<T>(output, input.lengths, mLogProbsTiled);

mDynamicDecodeLayer->forward(outputParams, forwardParams);
}

namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime

// this should be similar to gatherTree in cpp/tensorrt_llm/thop/gatherTreeOp.cpp
void IGptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
template <typename T>
void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput,
DecodingInput const& decodingInput, BufferManager const& manager)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@@ -300,7 +303,7 @@ void IGptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& deco
beamHypotheses.sequence_lengths_src = bufferCast<SizeType>(*decodingOutput.lengths);
beamHypotheses.parent_ids_src = bufferCast<TokenIdType>(*decodingOutput.parentIds);
beamHypotheses.output_ids_src = bufferCast<TokenIdType>(*decodingOutput.ids);
beamHypotheses.log_probs_src = nullptr;
beamHypotheses.log_probs_src = bufferCast<float>(*mLogProbsTiled);
beamHypotheses.max_seq_len = maxSeqLength;
beamHypotheses.length_penalties
= nullptr; // TODO (bhsueh) should set length penalties, this should be a gpu tensor When it is set as
@@ -316,17 +319,24 @@ void IGptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& deco
beamHypotheses.is_done = bufferCast<bool>(*decodingOutput.beamHypotheses.isDone);
beamHypotheses.input_lengths = bufferCast<SizeType>(*decodingInput.lengths);

// This is where transpose is done
tensorrt_llm::kernels::invokeInsertUnfinishedPath(beamHypotheses, bufferCast<bool>(*decodingOutput.finished),
bufferCast<float>(*decodingOutput.cumLogProbs), batchSize, beamWidth, stream.get());
sync_check_cuda_error();

tensorrt_llm::kernels::invokeFinalize(bufferCast<TokenIdType>(finalOutputIds),
bufferCast<SizeType>(*decodingOutput.lengths), bufferCast<float>(*decodingOutput.cumLogProbs),
nullptr, // output_logs
beamHypotheses.output_ids_tgt, beamHypotheses.sequence_lengths_tgt, beamHypotheses.normed_scores,
beamHypotheses.cum_log_probs, beamHypotheses.log_probs, beamHypotheses.num_beams, beamHypotheses.input_lengths,
beamWidth, maxSeqLength, batchSize, stream.get());
decodingOutput.logProbs ? bufferCast<float>(*decodingOutput.logProbs) : nullptr, beamHypotheses.output_ids_tgt,
beamHypotheses.sequence_lengths_tgt, beamHypotheses.normed_scores, beamHypotheses.cum_log_probs,
beamHypotheses.log_probs, beamHypotheses.num_beams, beamHypotheses.input_lengths, beamWidth, maxSeqLength,
batchSize, stream.get());
sync_check_cuda_error();

TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime
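For reference, the widened `GptDecoder<T>::setup` above now receives `maxSequenceLength` and reshapes `mLogProbsTiled` to `{maxSequenceLength, batchSize, beamWidth}` before zeroing it, so per-step log probabilities can later be consumed by `gatherTree`. A minimal stand-alone sketch of that layout, assuming the usual dense row-major storage of `ITensor` (names here are illustrative, not part of the diff):

```cpp
#include <cstddef>
#include <vector>

// Sketch only: index into a dense [maxSeqLen, batchSize, beamWidth] log-prob buffer,
// the shape given to mLogProbsTiled in GptDecoder<T>::setup above.
std::size_t tiledLogProbIndex(std::size_t step, std::size_t batch, std::size_t beam,
    std::size_t batchSize, std::size_t beamWidth)
{
    return (step * batchSize + batch) * beamWidth + beam;
}

int main()
{
    std::size_t const maxSeqLen = 8, batchSize = 2, beamWidth = 4;
    std::vector<float> logProbsTiled(maxSeqLen * batchSize * beamWidth, 0.0f); // setZero equivalent
    logProbsTiled[tiledLogProbIndex(/*step=*/3, /*batch=*/1, /*beam=*/2, batchSize, beamWidth)] = -0.25f;
    return 0;
}
```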
@ -100,6 +100,7 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
|
||||
dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
dOutput->logProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
dOutput->beamHypotheses.empty(mBufferManager);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -144,10 +145,14 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
|
||||
dOutput.finishedSum->reshape(maxBatchSizeShape);
|
||||
mBufferManager.setZero(*dOutput.finishedSum);
|
||||
|
||||
dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.cumLogProbs);
|
||||
|
||||
dOutput.logProbs->reshape(ITensor::makeShape({maxBatchSize, maxBeamWidth, mMaxSequenceLength}));
|
||||
mBufferManager.setZero(*dOutput.logProbs);
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.cumLogProbs);
|
||||
dOutput.beamHypotheses.reshape(maxBatchSize, maxBeamWidth, mMaxSequenceLength);
|
||||
}
|
||||
else
|
||||
@ -262,13 +267,25 @@ void GptDecoderBatch::newRequest(
|
||||
dOutput->newTokens = ITensor::slice(dJointOutput.newTokens, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->newTokens);
|
||||
|
||||
if (beamWidth > 1)
|
||||
// cumLogProb is mandatory for beamWidth > 1
|
||||
dOutput->cumLogProbs = nullptr;
|
||||
if (request.computeCumLogProbs || beamWidth > 1)
|
||||
{
|
||||
dOutput->cumLogProbs = ITensor::slice(dJointOutput.cumLogProbs, batchIdx, localBatchSize);
|
||||
manager.setZero(*IBuffer::slice(dOutput->cumLogProbs, 0, 1));
|
||||
manager.setZero(*dOutput->cumLogProbs);
|
||||
}
|
||||
|
||||
dOutput->logProbs = nullptr;
|
||||
if (request.computeLogProbs)
|
||||
{
|
||||
dOutput->logProbs = ITensor::slice(dJointOutput.logProbs, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->logProbs);
|
||||
}
|
||||
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
kernels::invokeFill(
|
||||
*IBuffer::slice(dOutput->cumLogProbs, 1, beamWidth - 1), DecodingOutput::kNegativeInfinity, *stream);
|
||||
|
||||
dOutput->parentIds = ITensor::slice(dJointOutput.parentIds, batchIdx, localBatchSize);
|
||||
dOutput->parentIds->reshape(outputIdsShape);
|
||||
manager.setZero(*dOutput->parentIds);
|
||||
@ -277,7 +294,7 @@ void GptDecoderBatch::newRequest(
|
||||
}
|
||||
|
||||
// remaining
|
||||
mDecoders[batchIdx]->setup(samplingConfig, localBatchSize);
|
||||
mDecoders[batchIdx]->setup(samplingConfig, localBatchSize, mMaxSequenceLength);
|
||||
mBeamWidths[batchIdx] = beamWidth;
|
||||
mNbSteps[batchIdx] = 0;
|
||||
mFinished[batchIdx] = false;
|
||||
@ -407,6 +424,7 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& stream = mStreams[batchIdx];
|
||||
auto manager = BufferManager{stream};
|
||||
auto& decoder = *mDecoders[batchIdx];
|
||||
|
||||
auto& dInput = *mDecodingInputs[batchIdx];
|
||||
auto& dOutput = *mDecodingOutputs[batchIdx];
|
||||
@ -414,7 +432,7 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
// TODO can we do this inplace?
|
||||
auto& outputIds = dOutput.ids;
|
||||
auto finalOutputIds = manager.gpu(outputIds->getShape(), outputIds->getDataType());
|
||||
IGptDecoder::gatherTree(*finalOutputIds, dOutput, dInput, manager);
|
||||
decoder.gatherTree(*finalOutputIds, dOutput, dInput, manager);
|
||||
manager.copy(*finalOutputIds, *outputIds);
|
||||
|
||||
CudaEvent event{};
|
||||
@ -424,7 +442,8 @@ CudaEvent GptDecoderBatch::postProcessRequest(SizeType batchIdx) const
|
||||
return event;
|
||||
}
|
||||
|
||||
void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig)
|
||||
void GptDecoderBatch::newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
// split batch into single requests
|
||||
@ -458,7 +477,10 @@ void GptDecoderBatch::newBatch(GenerationInput const& inputs, SamplingConfig con
|
||||
inputView = ITensor::slice(inputs.ids, batchIdx, 1);
|
||||
inputView->reshape(inputShape);
|
||||
}
|
||||
auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId, inputs.padId};
|
||||
|
||||
auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId};
|
||||
request.computeCumLogProbs = (outputs.cumLogProbs != nullptr);
|
||||
request.computeLogProbs = (outputs.logProbs != nullptr);
|
||||
|
||||
if (inputs.embeddingBias)
|
||||
{
|
||||
@ -517,7 +539,7 @@ void GptDecoderBatch::forwardSync()
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
|
||||
void GptDecoderBatch::finalize() const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
for (SizeType batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
|
||||
@ -525,13 +547,12 @@ IStatefulGptDecoder::TensorPtr GptDecoderBatch::getFinalOutputIds() const
|
||||
postProcessRequest(batchIdx);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return getOutputIds();
|
||||
}
|
||||
|
||||
std::tuple<CudaEvent, IStatefulGptDecoder::TensorPtr> GptDecoderBatch::getFinalOutputIds(SizeType batchIdx) const
|
||||
CudaEvent GptDecoderBatch::finalize(SizeType batchIdx) const
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto event = postProcessRequest(batchIdx);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return {std::move(event), getOutputIds(batchIdx)};
|
||||
return event;
|
||||
}
|
||||
|
||||
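The hunks above also rename `GptDecoderBatch::getFinalOutputIds()` to `finalize()` and turn the per-request overload into one that returns only a `CudaEvent` instead of an `(event, ids)` tuple. A hedged sketch of how a caller might drive the per-request path, assuming an existing `GptDecoderBatch decoderBatch`, a `CudaStream stream`, and that `getOutputIds(batchIdx)` remains available as shown in the removed lines:

```cpp
// Sketch only: finalize one request and read its ids once post-processing is done.
SizeType const batchIdx = 0;
CudaEvent event = decoderBatch.finalize(batchIdx); // gathers beams on the request's stream
stream.wait(event.get());                          // make the main stream wait for post-processing
auto const outputIds = decoderBatch.getOutputIds(batchIdx); // final token ids for this request
```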
@@ -107,6 +107,7 @@ GptJsonConfig parseJson(InputType&& i)
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);

auto const computeContextLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);
auto const computeGenerationLogits = parseJsonFieldOr(builderConfig, "gather_all_token_logits", false);

auto const& pluginConfig = json.at("plugin_config");
auto const pagedKvCache = pluginConfig.at("paged_kv_cache");
@@ -125,6 +126,7 @@ GptJsonConfig parseJson(InputType&& i)
modelConfig.setQuantMode(quantMode);
modelConfig.setNbKvHeads(numKvHeads);
modelConfig.computeContextLogits(computeContextLogits);
modelConfig.computeGenerationLogits(computeGenerationLogits);

modelConfig.setMaxBatchSize(maxBatchSize);
modelConfig.setMaxInputLen(maxInputLen);
@@ -132,10 +134,10 @@ GptJsonConfig parseJson(InputType&& i)
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);

if (name == std::string("chatglm-6b"))
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
{
modelConfig.setModelVariant(GptModelConfig::ModelVariant::kGlm);
// kGlm is only for ChatGLM-6B and Glm-10B
// kGlm is only for ChatGLM-6B and GLM-10B
}

return GptJsonConfig{name, precision, tensorParallelism, pipelineParallelism, modelConfig};
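Both logits flags above are derived from the single `gather_all_token_logits` field with a default of `false`. A stand-alone sketch of that lookup-with-default using nlohmann::json directly; `parseJsonFieldOr` itself is not shown in this diff, so this only illustrates the behaviour, not its implementation:

```cpp
#include <nlohmann/json.hpp>

int main()
{
    auto const builderConfig = nlohmann::json::parse(R"({"max_batch_size": 8})");
    // json::value() returns the field if present and the supplied default otherwise.
    bool const gatherAllTokenLogits = builderConfig.value("gather_all_token_logits", false);
    bool const computeContextLogits = gatherAllTokenLogits;
    bool const computeGenerationLogits = gatherAllTokenLogits;
    return (computeContextLogits || computeGenerationLogits) ? 1 : 0;
}
```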
@ -55,7 +55,7 @@ GptSession::GptSession(Config const& sessionConfig, GptModelConfig const& modelC
|
||||
{
|
||||
if (mWorldConfig.isPipelineParallel())
|
||||
{
|
||||
mPipelineComm = NcclCommunicator::createPipelineComm(mWorldConfig, *mLogger);
|
||||
mPipelineComm = NcclCommunicator::createPipelineComm(mWorldConfig);
|
||||
mCommStream = std::make_shared<CudaStream>();
|
||||
}
|
||||
|
||||
@ -72,7 +72,7 @@ nvinfer1::ILogger& GptSession::getLogger() const
|
||||
return *mLogger;
|
||||
}
|
||||
|
||||
BufferManager& GptSession::getBufferManager() const
|
||||
BufferManager const& GptSession::getBufferManager() const
|
||||
{
|
||||
return mRuntime->getBufferManager();
|
||||
}
|
||||
@ -163,7 +163,8 @@ void GptSession::createKvCacheManager(SizeType batchSize, SizeType beamWidth, Si
|
||||
kvDtype = mModelConfig.getDataType();
|
||||
}
|
||||
|
||||
auto const maxNumTokens = bmkv::KVCacheManager::getMaxNumTokens(config, kvDtype, mModelConfig, mWorldConfig);
|
||||
auto const maxNumTokens
|
||||
= bmkv::KVCacheManager::getMaxNumTokens(config, kvDtype, mModelConfig, mWorldConfig, getBufferManager());
|
||||
TLLM_LOG_INFO("Using %d tokens in paged KV cache.", maxNumTokens);
|
||||
auto const maxNumBlocks = tc::ceilDiv(maxNumTokens, tokensPerBlock);
|
||||
auto const maxBlocksPerSeq = tc::ceilDiv(maxSequenceLength, tokensPerBlock);
|
||||
@ -302,12 +303,12 @@ void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId,
|
||||
}
|
||||
|
||||
ITensor::SharedPtr GptSession::initDecoder(ITensor& outputIds, GenerationInput const& inputs,
|
||||
SamplingConfig const& samplingConfig, SizeType microBatchId) const
|
||||
GenerationOutput const& outputs, SamplingConfig const& samplingConfig, SizeType microBatchId) const
|
||||
{
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
auto& decoder = mDecoders.at(microBatchId);
|
||||
decoder->newBatch(inputs, samplingConfig);
|
||||
decoder->newBatch(inputs, outputs, samplingConfig);
|
||||
return decoder->getNewTokens();
|
||||
}
|
||||
else if (mWorldConfig.isFirstPipelineParallelRank())
|
||||
@ -444,6 +445,39 @@ std::vector<GenerationInput> splitInputs(GenerationInput const& inputs, SizeType
|
||||
return inputBatches;
|
||||
}
|
||||
|
||||
std::vector<GenerationOutput> splitOutputs(GenerationOutput& outputs, SizeType microBatchSize, BufferManager& manager)
|
||||
{
|
||||
auto const numRequests = outputs.ids->getShape().d[0];
|
||||
|
||||
std::vector<GenerationOutput> outputBatches;
|
||||
for (auto batchOffset = 0; batchOffset < numRequests; batchOffset += microBatchSize)
|
||||
{
|
||||
auto const batchSize = std::min(microBatchSize, numRequests - batchOffset);
|
||||
|
||||
outputBatches.emplace_back(ITensor::slice(outputs.ids, batchOffset, batchSize),
|
||||
ITensor::slice(outputs.lengths, batchOffset, batchSize));
|
||||
|
||||
if (outputs.cumLogProbs)
|
||||
{
|
||||
outputBatches.back().cumLogProbs = ITensor::slice(outputs.cumLogProbs, batchOffset, batchSize);
|
||||
}
|
||||
if (outputs.logProbs)
|
||||
{
|
||||
outputBatches.back().logProbs = ITensor::slice(outputs.logProbs, batchOffset, batchSize);
|
||||
}
|
||||
if (outputs.contextLogits)
|
||||
{
|
||||
outputBatches.back().contextLogits = ITensor::slice(outputs.contextLogits, batchOffset, batchSize);
|
||||
}
|
||||
if (outputs.generationLogits)
|
||||
{
|
||||
outputBatches.back().generationLogits = ITensor::slice(outputs.generationLogits, batchOffset, batchSize);
|
||||
}
|
||||
}
|
||||
|
||||
return outputBatches;
|
||||
}
|
||||
|
||||
void updateOutputIds(ITensor::SharedPtr const& outputIds, ITensor::SharedPtr const& newTokens, SizeType decoderStep,
|
||||
CudaStream const& stream)
|
||||
{ // assemble outputIds of all micro batches
|
||||
@ -473,32 +507,86 @@ void GptSession::generate(
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
outputs.ids->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
|
||||
outputs.lengths->reshape(ITensor::makeShape({batchSize, beamWidth}));
|
||||
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(outputs.contextLogits,
|
||||
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
|
||||
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
|
||||
auto const inputLengthsHost = manager.copyFrom(*inputLengths, MemoryType::kCPU);
|
||||
auto const inputLengthsRange = BufferRange<SizeType>(*inputLengthsHost);
|
||||
auto const maxInputLength = *std::max_element(inputLengthsRange.begin(), inputLengthsRange.end());
|
||||
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
|
||||
if (outputs.cumLogProbs)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(outputs.cumLogProbs,
|
||||
"outputs.cumLogProbs is nullptr. It must be allocated when computeLogProbs is true");
|
||||
outputs.cumLogProbs->reshape(ITensor::makeShape({batchSize, beamWidth}));
|
||||
}
|
||||
if (outputs.logProbs)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
outputs.logProbs, "outputs.logProbs is nullptr. It must be allocated when computeLogProbs is true");
|
||||
outputs.logProbs->reshape(ITensor::makeShape({batchSize, beamWidth, mDecoderMaxSequenceLength}));
|
||||
}
|
||||
if (mModelConfig.computeContextLogits() || mModelConfig.computeGenerationLogits())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(outputs.contextLogits,
|
||||
"outputs.contextLogits is nullptr. It must be allocated when computeContextLogits() is enabled.");
|
||||
auto const vocabSizePadded = mModelConfig.getVocabSizePadded(mWorldConfig.getSize());
|
||||
auto const inputLengthsHost = manager.copyFrom(*inputLengths, MemoryType::kCPU);
|
||||
auto const inputLengthsRange = BufferRange<SizeType>(*inputLengthsHost);
|
||||
auto const maxInputLength = *std::max_element(inputLengthsRange.begin(), inputLengthsRange.end());
|
||||
|
||||
if (mModelConfig.computeContextLogits())
|
||||
{
|
||||
outputs.contextLogits->reshape(ITensor::makeShape({batchSize, maxInputLength, vocabSizePadded}));
|
||||
}
|
||||
|
||||
// Initialize the output generation logits buffer
|
||||
if (mModelConfig.computeGenerationLogits())
|
||||
{
|
||||
SizeType maxNewTokens = 0;
|
||||
if (inputs.maxNewTokens)
|
||||
{
|
||||
maxNewTokens = inputs.maxNewTokens.value();
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto iter = inputLengthsRange.begin(); iter != inputLengthsRange.end(); iter++)
|
||||
{
|
||||
maxNewTokens = std::max(maxNewTokens, mDecoderMaxSequenceLength - *iter);
|
||||
}
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(maxNewTokens, "maxNewTokens is null");
|
||||
|
||||
TLLM_CHECK_WITH_INFO(outputs.generationLogits,
|
||||
"outputs.generationLogits is nullptr. It must be allocated when computeGenerationLogits() is "
|
||||
"enabled.");
|
||||
outputs.generationLogits->reshape(
|
||||
ITensor::makeShape({batchSize, beamWidth, maxNewTokens - 1, vocabSizePadded}));
|
||||
auto const generationLogitsShape = outputs.generationLogits->getShape();
|
||||
TLLM_CHECK_WITH_INFO(generationLogitsShape.d[0] == batchSize, "Invalid dim[0]");
|
||||
TLLM_CHECK_WITH_INFO(generationLogitsShape.d[1] == beamWidth, "Invalid dim[1]");
|
||||
TLLM_CHECK_WITH_INFO(generationLogitsShape.d[2] == maxNewTokens - 1, "Invalid dim[2]");
|
||||
TLLM_CHECK_WITH_INFO(generationLogitsShape.d[3] == vocabSizePadded, "Invalid dim[3]");
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// callbacks
|
||||
auto const onTokenGenerated = createOnTokenGeneratedCallback(outputs);
|
||||
|
||||
if (batchSize <= mMicroBatchConfig.genBatchSize)
|
||||
{
|
||||
std::vector<GenerationInput> microBatches{inputs};
|
||||
generateBatched(outputs, microBatches, samplingConfig);
|
||||
std::vector<GenerationInput> microBatchesInputs{inputs};
|
||||
std::vector<GenerationOutput> microBatchesOutputs{outputs};
|
||||
generateBatched(microBatchesOutputs, microBatchesInputs, samplingConfig, onTokenGenerated);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const microBatches = splitInputs(inputs, mMicroBatchConfig.genBatchSize, manager);
|
||||
generateBatched(outputs, microBatches, samplingConfig);
|
||||
auto const microBatchesInputs = splitInputs(inputs, mMicroBatchConfig.genBatchSize, manager);
|
||||
auto microBatchesOutputs = splitOutputs(outputs, mMicroBatchConfig.genBatchSize, manager);
|
||||
generateBatched(microBatchesOutputs, microBatchesInputs, samplingConfig, onTokenGenerated);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
std::function<void(SizeType step, bool finished)> GptSession::createOnTokenGeneratedCallback(GenerationOutput& outputs)
|
||||
GptSession::TokenGeneratedCallback GptSession::createOnTokenGeneratedCallback(GenerationOutput& outputs)
|
||||
{
|
||||
if (outputs.onTokenGenerated && mWorldConfig.isFirstPipelineParallelRank())
|
||||
{
|
||||
@ -514,13 +602,15 @@ std::function<void(SizeType step, bool finished)> GptSession::createOnTokenGener
|
||||
}
|
||||
}
|
||||
|
||||
void GptSession::generateBatched(
|
||||
GenerationOutput& outputs, std::vector<GenerationInput> const& microBatches, SamplingConfig const& samplingConfig)
|
||||
void GptSession::generateBatched(std::vector<GenerationOutput>& microBatchesOutputs,
|
||||
std::vector<GenerationInput> const& microBatchesInputs, SamplingConfig const& samplingConfig,
|
||||
TokenGeneratedCallback const& onTokenGenerated)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto& manager = mRuntime->getBufferManager();
|
||||
auto const numMicroBatches = static_cast<SizeType>(microBatches.size());
|
||||
TLLM_CHECK(microBatchesInputs.size() == microBatchesOutputs.size());
|
||||
auto const numMicroBatches = static_cast<SizeType>(microBatchesInputs.size());
|
||||
TLLM_CHECK(numMicroBatches > 0);
|
||||
TLLM_CHECK(numMicroBatches <= mMicroBatchConfig.numGenBatches);
|
||||
SizeType const beamWidth{samplingConfig.beamWidth};
|
||||
@ -528,7 +618,7 @@ void GptSession::generateBatched(
|
||||
// Initialize and reshape buffers
|
||||
for (auto microBatchId = 0; microBatchId < numMicroBatches; ++microBatchId)
|
||||
{
|
||||
auto const& microBatchInputs = microBatches.at(microBatchId);
|
||||
auto const& microBatchInputs = microBatchesInputs.at(microBatchId);
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
buffers.initFromInput(*microBatchInputs.ids, microBatchInputs.lengths, microBatchInputs.packed, beamWidth,
|
||||
mDecoderMaxKvCacheLength, mDecoderMaxSequenceLength, manager);
|
||||
@ -549,14 +639,29 @@ void GptSession::generateBatched(
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
auto const batchOffset = microBatchOffsets.at(microBatchId);
|
||||
kvCacheAddSequences(beamWidth, microBatchId, batchOffset);
|
||||
auto const& microBatchInputs = microBatches.at(microBatchId);
|
||||
auto const microBatchSize = buffers.generationConfig.batchSize;
|
||||
buffers.outputIds = ITensor::slice(outputs.ids, batchOffset, microBatchSize);
|
||||
buffers.outputLengths = ITensor::slice(outputs.lengths, batchOffset, microBatchSize);
|
||||
buffers.newTokens = initDecoder(*buffers.outputIds, microBatchInputs, samplingConfig, microBatchId);
|
||||
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.computeContextLogits())
|
||||
auto const& microBatchInputs = microBatchesInputs.at(microBatchId);
|
||||
auto& microBatchOutputs = microBatchesOutputs.at(microBatchId);
|
||||
buffers.outputIds = microBatchOutputs.ids;
|
||||
buffers.outputLengths = microBatchOutputs.lengths;
|
||||
buffers.newTokens
|
||||
= initDecoder(*buffers.outputIds, microBatchInputs, microBatchOutputs, samplingConfig, microBatchId);
|
||||
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
buffers.logits = ITensor::slice(outputs.contextLogits, batchOffset, microBatchSize);
|
||||
buffers.cumLogProbs = nullptr;
|
||||
if (microBatchOutputs.cumLogProbs)
|
||||
{
|
||||
buffers.cumLogProbs = microBatchOutputs.cumLogProbs;
|
||||
}
|
||||
buffers.logProbs = nullptr;
|
||||
if (microBatchOutputs.logProbs)
|
||||
{
|
||||
buffers.logProbs = microBatchOutputs.logProbs;
|
||||
}
|
||||
if (mModelConfig.computeContextLogits())
|
||||
{
|
||||
buffers.logits = microBatchOutputs.contextLogits;
|
||||
}
|
||||
}
|
||||
if (mModelConfig.usePromptTuning())
|
||||
{
|
||||
@ -564,9 +669,6 @@ void GptSession::generateBatched(
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare the onTokenGenerated callback
|
||||
auto const onTokenGenerated = createOnTokenGeneratedCallback(outputs);
|
||||
|
||||
if (useCudaGraphs())
|
||||
{
|
||||
for (auto& instance : mCudaGraphInstances)
|
||||
@ -577,7 +679,7 @@ void GptSession::generateBatched(
|
||||
|
||||
auto kvCacheManager = mModelConfig.usePagedKvCache() ? mKvCacheManager.get() : nullptr;
|
||||
|
||||
executeContextStep(microBatches, microBatchOffsets, kvCacheManager);
|
||||
executeContextStep(microBatchesInputs, microBatchOffsets, kvCacheManager);
|
||||
|
||||
std::vector<bool> microBatchesFinished(numMicroBatches, false);
|
||||
SizeType numBatchesFinished{0};
|
||||
@ -585,8 +687,8 @@ void GptSession::generateBatched(
|
||||
while (numBatchesFinished < numMicroBatches)
|
||||
{
|
||||
++step;
|
||||
numBatchesFinished
|
||||
+= executeGenerationStep(step, microBatches, microBatchOffsets, kvCacheManager, microBatchesFinished);
|
||||
numBatchesFinished += executeGenerationStep(
|
||||
step, microBatchesInputs, microBatchesOutputs, microBatchOffsets, kvCacheManager, microBatchesFinished);
|
||||
|
||||
onTokenGenerated(step - 1, numBatchesFinished == numMicroBatches);
|
||||
}
|
||||
@ -608,9 +710,23 @@ void GptSession::generateBatched(
|
||||
|
||||
// TODO(micro batching) use mCommStream?
|
||||
if (beamWidth > 1)
|
||||
finalizeOutputIds(microBatchId);
|
||||
{
|
||||
finalize(microBatchId);
|
||||
}
|
||||
else if (!mWorldConfig.isPipelineParallel())
|
||||
manager.copy(*mDecoders.at(microBatchId)->getOutputIds(), *mBuffers.at(microBatchId)->outputIds);
|
||||
{
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
auto& decoder = *mDecoders.at(microBatchId);
|
||||
manager.copy(*decoder.getOutputIds(), *buffers.outputIds);
|
||||
|
||||
auto& cumLogProbs = buffers.cumLogProbs;
|
||||
if (cumLogProbs)
|
||||
manager.copy(*decoder.getCumLogProbs(), *buffers.cumLogProbs);
|
||||
|
||||
auto& logProbs = buffers.logProbs;
|
||||
if (logProbs)
|
||||
manager.copy(*decoder.getLogProbs(), *buffers.logProbs);
|
||||
}
|
||||
}
|
||||
|
||||
manager.getStream().synchronize();
|
||||
@ -668,14 +784,15 @@ void GptSession::executeContextStep(std::vector<GenerationInput> const& generati
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
SizeType GptSession::executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatches,
|
||||
std::vector<SizeType> const& microBatchOffsets, KvCacheManager* kvCacheManager,
|
||||
std::vector<bool>& microBatchesFinished)
|
||||
SizeType GptSession::executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatchesInputs,
|
||||
std::vector<GenerationOutput>& microBatchesOutputs, std::vector<SizeType> const& microBatchOffsets,
|
||||
KvCacheManager* kvCacheManager, std::vector<bool>& microBatchesFinished)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(microBatchesInputs.size() == microBatchesOutputs.size());
|
||||
auto& manager = mRuntime->getBufferManager();
|
||||
|
||||
auto const numMicroBatches = static_cast<SizeType>(microBatches.size());
|
||||
auto const numMicroBatches = static_cast<SizeType>(microBatchesInputs.size());
|
||||
SizeType numBatchesFinished{0};
|
||||
|
||||
auto const flipFlopId = step % 2;
|
||||
@ -725,6 +842,18 @@ SizeType GptSession::executeGenerationStep(SizeType step, std::vector<Generation
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
|
||||
if (mModelConfig.computeGenerationLogits())
|
||||
{
|
||||
auto& outputs = microBatchesOutputs.at(generationBatchId);
|
||||
auto const firstBatchSlotIdx = microBatchOffsets.at(generationBatchId);
|
||||
auto const microBatchSize = buffers.generationConfig.batchSize;
|
||||
auto const beamWidth = buffers.generationConfig.beamWidth;
|
||||
|
||||
buffers.postEachGenerationStep(manager, outputs.generationLogits, step - 1, firstBatchSlotIdx,
|
||||
microBatchSize, beamWidth, mWorldConfig);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
|
||||
std::swap(buffers.cacheIndirectionDecoderInput, buffers.cacheIndirectionDecoderOutput);
|
||||
|
||||
auto const decoderStep = generationConfig.maxInputLength + step;
|
||||
@ -765,14 +894,14 @@ void GptSession::decoderStepAsync(SizeType decoderStep, SizeType microBatchId)
|
||||
auto const beamWidth = cacheIndirection.getShape().d[1];
|
||||
for (auto peerIdx = 0; peerIdx < mWorldConfig.getPipelineParallelism() - 1; ++peerIdx)
|
||||
{
|
||||
mPipelineComm->send<SizeType>(*decoder.getNbFinished(), pipelineGroup[peerIdx], *mCommStream, *mLogger);
|
||||
mPipelineComm->send<SizeType>(*decoder.getNbFinished(), pipelineGroup[peerIdx], *mCommStream);
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
mPipelineComm->send<SizeType>(cacheIndirection, pipelineGroup[peerIdx], *mCommStream, *mLogger);
|
||||
mPipelineComm->send<SizeType>(cacheIndirection, pipelineGroup[peerIdx], *mCommStream);
|
||||
}
|
||||
mPipelineComm->send<SizeType>(sequenceLengths, pipelineGroup[peerIdx], *mCommStream, *mLogger);
|
||||
mPipelineComm->send<SizeType>(sequenceLengths, pipelineGroup[peerIdx], *mCommStream);
|
||||
}
|
||||
mPipelineComm->send<TokenIdType>(*decoder.getNewTokens(), pipelineGroup.front(), *mCommStream, *mLogger);
|
||||
mPipelineComm->send<TokenIdType>(*decoder.getNewTokens(), pipelineGroup.front(), *mCommStream);
|
||||
}
|
||||
}
|
||||
else // pipeline parallel mode
|
||||
@ -781,19 +910,19 @@ void GptSession::decoderStepAsync(SizeType decoderStep, SizeType microBatchId)
|
||||
mCommStream->wait(mCommEvent.get());
|
||||
auto const pipelineGroup = mWorldConfig.getPipelineParallelGroup();
|
||||
auto const peer = pipelineGroup.back();
|
||||
mPipelineComm->receive<SizeType>(*buffers.nbFinished, peer, *mCommStream, *mLogger);
|
||||
mPipelineComm->receive<SizeType>(*buffers.nbFinished, peer, *mCommStream);
|
||||
|
||||
auto& cacheIndirection = *buffers.cacheIndirectionDecoderOutput;
|
||||
auto& sequenceLengths = *buffers.sequenceLengths;
|
||||
auto const beamWidth = cacheIndirection.getShape().d[1];
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
mPipelineComm->receive<SizeType>(cacheIndirection, peer, *mCommStream, *mLogger);
|
||||
mPipelineComm->receive<SizeType>(cacheIndirection, peer, *mCommStream);
|
||||
}
|
||||
mPipelineComm->receive<SizeType>(sequenceLengths, peer, *mCommStream, *mLogger);
|
||||
mPipelineComm->receive<SizeType>(sequenceLengths, peer, *mCommStream);
|
||||
if (mWorldConfig.isFirstPipelineParallelRank())
|
||||
{ // receive newTokens from last rank on a separate stream
|
||||
mPipelineComm->receive<TokenIdType>(*newTokens, peer, *mCommStream, *mLogger);
|
||||
mPipelineComm->receive<TokenIdType>(*newTokens, peer, *mCommStream);
|
||||
updateOutputIds(outputIds, newTokens, decoderStep, *mCommStream);
|
||||
}
|
||||
mCommStream->record(mReceivedEvents.at(microBatchId).get());
|
||||
@ -837,12 +966,16 @@ bool GptSession::shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType
|
||||
return nbFinished == batchSize * beamWidth;
|
||||
}
|
||||
|
||||
void GptSession::finalizeOutputIds(SizeType microBatchId)
|
||||
void GptSession::finalize(SizeType microBatchId)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& manager = mRuntime->getBufferManager();
|
||||
auto& outputIds = *mBuffers.at(microBatchId)->outputIds;
|
||||
auto& sequenceLengths = *mBuffers.at(microBatchId)->sequenceLengths;
|
||||
auto& buffers = *mBuffers.at(microBatchId);
|
||||
auto& decoder = mDecoders.at(microBatchId);
|
||||
auto& outputIds = buffers.outputIds;
|
||||
auto& cumLogProbs = buffers.cumLogProbs;
|
||||
auto& logProbs = buffers.logProbs;
|
||||
auto& sequenceLengths = buffers.sequenceLengths;
|
||||
|
||||
if (mWorldConfig.isPipelineParallel())
|
||||
{
|
||||
@ -852,20 +985,56 @@ void GptSession::finalizeOutputIds(SizeType microBatchId)
|
||||
if (mWorldConfig.isLastPipelineParallelRank())
|
||||
{ // send ids from last to first
|
||||
auto const peer = pipelineGroup.front();
|
||||
auto const finalOutputIds = mDecoders.at(microBatchId)->getFinalOutputIds();
|
||||
mPipelineComm->send<TokenIdType>(*finalOutputIds, peer, stream, *mLogger);
|
||||
mPipelineComm->send<SizeType>(sequenceLengths, peer, stream, *mLogger);
|
||||
decoder->finalize();
|
||||
auto finalOutputIds = decoder->getOutputIds();
|
||||
|
||||
mPipelineComm->send<TokenIdType>(*finalOutputIds, peer, stream);
|
||||
mPipelineComm->send<SizeType>(*sequenceLengths, peer, stream);
|
||||
manager.copy(*finalOutputIds, *outputIds);
|
||||
|
||||
if (cumLogProbs)
|
||||
{
|
||||
auto finalCumLogProbs = decoder->getCumLogProbs();
|
||||
mPipelineComm->send<float>(*finalCumLogProbs, peer, stream);
|
||||
manager.copy(*finalCumLogProbs, *cumLogProbs);
|
||||
}
|
||||
if (logProbs)
|
||||
{
|
||||
auto finalLogProbs = decoder->getLogProbs();
|
||||
mPipelineComm->send<float>(*finalLogProbs, peer, stream);
|
||||
manager.copy(*finalLogProbs, *logProbs);
|
||||
}
|
||||
}
|
||||
else if (mWorldConfig.isFirstPipelineParallelRank())
|
||||
{ // receive ids from last on first
|
||||
auto const peer = pipelineGroup.back();
|
||||
mPipelineComm->receive<TokenIdType>(outputIds, peer, stream, *mLogger);
|
||||
mPipelineComm->receive<SizeType>(sequenceLengths, peer, stream, *mLogger);
|
||||
mPipelineComm->receive<TokenIdType>(*outputIds, peer, stream);
|
||||
mPipelineComm->receive<SizeType>(*sequenceLengths, peer, stream);
|
||||
if (cumLogProbs)
|
||||
{
|
||||
mPipelineComm->receive<float>(*cumLogProbs, peer, stream);
|
||||
}
|
||||
if (logProbs)
|
||||
{
|
||||
mPipelineComm->receive<float>(*logProbs, peer, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
manager.copy(*mDecoders.at(microBatchId)->getFinalOutputIds(), outputIds);
|
||||
decoder->finalize();
|
||||
auto finalOutputIds = decoder->getOutputIds();
|
||||
manager.copy(*finalOutputIds, *outputIds);
|
||||
if (cumLogProbs)
|
||||
{
|
||||
auto finalCumLogProbs = decoder->getCumLogProbs();
|
||||
manager.copy(*finalCumLogProbs, *cumLogProbs);
|
||||
}
|
||||
if (logProbs)
|
||||
{
|
||||
auto finalLogProbs = decoder->getLogProbs();
|
||||
manager.copy(*finalLogProbs, *logProbs);
|
||||
}
|
||||
// sequenceLengths are already updated by decoder
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/bufferView.h"

@@ -79,3 +80,31 @@ std::ostream& tensorrt_llm::runtime::operator<<(std::ostream& output, IBuffer co
ITensor::makeShape({static_cast<SizeType>(buffer.getSize())}), buffer.getCapacity());
return output << *tensor;
}

char const* IBuffer::getDataTypeName() const
{
switch (getDataType())
{
case nvinfer1::DataType::kINT64: return DataTypeTraits<nvinfer1::DataType::kINT64>::name;
case nvinfer1::DataType::kINT32: return DataTypeTraits<nvinfer1::DataType::kINT32>::name;
case nvinfer1::DataType::kFLOAT: return DataTypeTraits<nvinfer1::DataType::kFLOAT>::name;
case nvinfer1::DataType::kBF16: return DataTypeTraits<nvinfer1::DataType::kBF16>::name;
case nvinfer1::DataType::kHALF: return DataTypeTraits<nvinfer1::DataType::kHALF>::name;
case nvinfer1::DataType::kBOOL: return DataTypeTraits<nvinfer1::DataType::kBOOL>::name;
case nvinfer1::DataType::kUINT8: return DataTypeTraits<nvinfer1::DataType::kUINT8>::name;
case nvinfer1::DataType::kINT8: return DataTypeTraits<nvinfer1::DataType::kINT8>::name;
case nvinfer1::DataType::kFP8: return DataTypeTraits<nvinfer1::DataType::kFP8>::name;
}
TLLM_THROW("Unknown data type");
}

char const* IBuffer::getMemoryTypeName() const
{
switch (getMemoryType())
{
case MemoryType::kPINNED: return MemoryTypeString<MemoryType::kPINNED>::value;
case MemoryType::kCPU: return MemoryTypeString<MemoryType::kCPU>::value;
case MemoryType::kGPU: return MemoryTypeString<MemoryType::kGPU>::value;
}
TLLM_THROW("Unknown memory type");
}
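The new `getDataTypeName()` and `getMemoryTypeName()` helpers above make buffer metadata printable. A short usage sketch, assuming an `IBuffer const& buffer` is in scope (the logging macro is the same one used throughout this diff):

```cpp
// Sketch only: human-readable buffer description for debug logging.
TLLM_LOG_DEBUG("buffer holds %zu elements of type %s in %s memory",
    buffer.getSize(), buffer.getDataTypeName(), buffer.getMemoryTypeName());
```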
@@ -64,39 +64,38 @@ struct NcclDataType<std::int32_t>
} // namespace

template <typename T>
void NcclCommunicator::send(
T* sendbuff, size_t count, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const
void NcclCommunicator::send(T* sendbuff, size_t count, int peer, CudaStream const& stream) const
{
#if ENABLE_MULTI_DEVICE
auto datatype = NcclDataType<std::remove_cv_t<T>>::value;
TLLM_NCCL_CHECK(ncclSend(sendbuff, count, datatype, peer, mComm, stream.get()), logger);
TLLM_NCCL_CHECK(ncclSend(sendbuff, count, datatype, peer, mComm, stream.get()));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

template void NcclCommunicator::send(std::uint8_t*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::send(std::int32_t*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::send(std::uint8_t const*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::send(std::int32_t const*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::send(std::uint8_t*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::send(std::int32_t*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::send(std::uint8_t const*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::send(std::int32_t const*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::send(float const*, size_t, int, CudaStream const&) const;

template <typename T>
void NcclCommunicator::receive(
T* sendbuff, size_t count, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const
void NcclCommunicator::receive(T* sendbuff, size_t count, int peer, CudaStream const& stream) const
{
#if ENABLE_MULTI_DEVICE
auto datatype = NcclDataType<std::remove_cv_t<T>>::value;
TLLM_NCCL_CHECK(ncclRecv(sendbuff, count, datatype, peer, mComm, stream.get()), logger);
TLLM_NCCL_CHECK(ncclRecv(sendbuff, count, datatype, peer, mComm, stream.get()));
#else
TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

template void NcclCommunicator::receive(std::uint8_t*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::receive(std::int32_t*, size_t, int, CudaStream const&, nvinfer1::ILogger&) const;
template void NcclCommunicator::receive(std::uint8_t*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::receive(std::int32_t*, size_t, int, CudaStream const&) const;
template void NcclCommunicator::receive(float*, size_t, int, CudaStream const&) const;

std::shared_ptr<NcclCommunicator> NcclCommunicator::createPipelineComm(
WorldConfig const& worldConfig, nvinfer1::ILogger& logger)
std::shared_ptr<NcclCommunicator> NcclCommunicator::createPipelineComm(WorldConfig const& worldConfig)
{
#if ENABLE_MULTI_DEVICE
int const myRank = worldConfig.getRank();
@@ -108,18 +107,18 @@ std::shared_ptr<NcclCommunicator> NcclCommunicator::createPipelineComm(
ncclGetUniqueId(&id);
for (auto peer = 1; peer < worldSize; ++peer)
{
TLLM_MPI_CHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, peer, 0, MPI_COMM_WORLD), logger);
TLLM_MPI_CHECK(MPI_Send(&id, sizeof(id), MPI_BYTE, peer, 0, MPI_COMM_WORLD));
}
}
else
{
auto constexpr peer = 0;
MPI_Status status;
TLLM_MPI_CHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, peer, 0, MPI_COMM_WORLD, &status), logger);
TLLM_MPI_CHECK(MPI_Recv(&id, sizeof(id), MPI_BYTE, peer, 0, MPI_COMM_WORLD, &status));
}

auto pipelineComm = std::make_shared<NcclCommunicator>();
TLLM_NCCL_CHECK(ncclCommInitRank(&pipelineComm->mComm, worldSize, id, myRank), logger);
TLLM_NCCL_CHECK(ncclCommInitRank(&pipelineComm->mComm, worldSize, id, myRank));

return pipelineComm;
#else
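`createPipelineComm` above bootstraps NCCL by having rank 0 create the unique id and ship it to every other rank with point-to-point `MPI_Send`/`MPI_Recv` before all ranks call `ncclCommInitRank`. As a side note on the pattern (an observation, not a requirement of this code), the same exchange can be written as a single collective:

```cpp
// Sketch only: equivalent id distribution via a broadcast from rank 0.
ncclUniqueId id;
if (myRank == 0)
{
    ncclGetUniqueId(&id);
}
MPI_Bcast(&id, sizeof(id), MPI_BYTE, /*root=*/0, MPI_COMM_WORLD);
```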
@@ -30,25 +30,24 @@ class NcclCommunicator
{
public:
template <typename T>
void send(T* sendbuff, size_t count, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const;
void send(T* sendbuff, size_t count, int peer, CudaStream const& stream) const;

template <typename T>
void send(IBuffer const& buf, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const
void send(IBuffer const& buf, int peer, CudaStream const& stream) const
{
send(bufferCast<T>(buf), buf.getSize(), peer, stream, logger);
send(bufferCast<T>(buf), buf.getSize(), peer, stream);
}

template <typename T>
void receive(T* sendbuff, size_t count, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const;
void receive(T* sendbuff, size_t count, int peer, CudaStream const& stream) const;

template <typename T>
void receive(IBuffer& buf, int peer, CudaStream const& stream, nvinfer1::ILogger& logger) const
void receive(IBuffer& buf, int peer, CudaStream const& stream) const
{
receive(bufferCast<T>(buf), buf.getSize(), peer, stream, logger);
receive(bufferCast<T>(buf), buf.getSize(), peer, stream);
}

static std::shared_ptr<NcclCommunicator> createPipelineComm(
WorldConfig const& worldConfig, nvinfer1::ILogger& logger);
static std::shared_ptr<NcclCommunicator> createPipelineComm(WorldConfig const& worldConfig);

private:
ncclComm_t mComm;
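With the `nvinfer1::ILogger` parameter removed, a point-to-point transfer now takes only the buffer, the peer rank, and the stream. A hedged sketch of the typical pipeline-parallel handshake, assuming an initialized `comm` (`std::shared_ptr<NcclCommunicator>`), a `WorldConfig worldConfig`, a caller-chosen `peer` and `stream`, and a tensor like the ones `GptSession` exchanges elsewhere in this diff:

```cpp
// Sketch only: exchange sequence lengths between the last and the first pipeline rank.
// `sequenceLengths` is assumed to be an ITensor::SharedPtr with matching shapes on both sides.
if (worldConfig.isLastPipelineParallelRank())
{
    comm->send<SizeType>(*sequenceLengths, peer, stream);
}
else if (worldConfig.isFirstPipelineParallelRank())
{
    comm->receive<SizeType>(*sequenceLengths, peer, stream);
}
```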
@ -83,6 +83,9 @@ void RuntimeBuffers::clear()
|
||||
cacheIndirectionDecoderInput = nullptr;
|
||||
cacheIndirectionDecoderOutput = nullptr;
|
||||
|
||||
cumLogProbs = nullptr;
|
||||
logProbs = nullptr;
|
||||
|
||||
hiddenStates = nullptr;
|
||||
|
||||
allocated = false;
|
||||
@ -155,10 +158,8 @@ void RuntimeBuffers::create(TllmRuntime& runtime, GptModelConfig const& modelCon
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
maxKvCacheLengths.emplace_back(manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32));
|
||||
}
|
||||
maxKvCacheLengths
|
||||
= utils::createBufferVector(runtime, localNbLayers, MemoryType::kCPU, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -238,11 +239,8 @@ void RuntimeBuffers::reshape(GptModelConfig const& modelConfig, WorldConfig cons
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
pastKeyValueLengths->reshape(ITensor::makeShape({batchSize}));
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
maxKvCacheLengths[i]->reshape(ITensor::makeShape({1}));
|
||||
}
|
||||
requestTypes->reshape(ITensor::makeShape({batchSize}));
|
||||
utils::reshapeBufferVector(maxKvCacheLengths, ITensor::makeShape({1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -511,6 +509,21 @@ void RuntimeBuffers::postContextStep(std::vector<RuntimeBuffers> const& contextB
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::postEachGenerationStep(BufferManager& manager, TensorPtr outputGenerationLogits, SizeType step,
|
||||
SizeType firstBatchSlotIdx, SizeType microBatchSize, SizeType beamWidth, WorldConfig const& worldConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (worldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
kernels::copyLatestTokenLogitsInGeneration(
|
||||
*outputGenerationLogits, *logits, step, firstBatchSlotIdx, microBatchSize, beamWidth, manager.getStream());
|
||||
manager.getStream().synchronize();
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType const padId, BufferManager& manager,
|
||||
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig)
|
||||
@ -523,6 +536,9 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
// use context lengths only in context step
|
||||
sequenceLengths = contextLengthsDevice;
|
||||
|
||||
// get local number of layers.
|
||||
auto const localNbLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism());
|
||||
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
auto pastKeyValueLengthsPtr = bufferCast<SizeType>(*pastKeyValueLengths);
|
||||
@ -534,7 +550,7 @@ void RuntimeBuffers::prepareContextStep(TensorPtr const& inputIds, TokenIdType c
|
||||
std::fill_n(RequestTypesPtr, batchSize, 0);
|
||||
|
||||
// Set maxKvCacheLengths buffer to the same value currently.
|
||||
for (auto layer = 0; layer < modelConfig.getNbLayers(); ++layer)
|
||||
for (auto layer = 0; layer < localNbLayers; ++layer)
|
||||
{
|
||||
bufferCast<SizeType>(*maxKvCacheLengths[layer])[0] = generationConfig.maxKvCacheLength;
|
||||
}
|
||||
@ -803,12 +819,7 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu
|
||||
inputBuffers.insert_or_assign("host_past_key_value_lengths", pastKeyValueLengths);
|
||||
inputBuffers.insert_or_assign("host_request_types", requestTypes);
|
||||
inputBuffers.insert_or_assign("sequence_length", sequenceLengths);
|
||||
|
||||
for (SizeType i = 0; i < modelConfig.getNbLayers(); ++i)
|
||||
{
|
||||
std::string name = "host_max_kv_cache_length_" + std::to_string(i);
|
||||
inputBuffers.insert_or_assign(name, maxKvCacheLengths[i]);
|
||||
}
|
||||
utils::insertTensorVector(inputBuffers, "host_max_kv_cache_length_", maxKvCacheLengths, firstLayerId);
|
||||
|
||||
if (modelConfig.usePackedInput())
|
||||
{
|
||||
|
||||
@ -106,6 +106,10 @@ public:
|
||||
// decoder
|
||||
TensorPtr nbFinished;
|
||||
|
||||
// Log probs
|
||||
TensorPtr cumLogProbs;
|
||||
TensorPtr logProbs;
|
||||
|
||||
// pipeline parallelism
|
||||
TensorPtr hiddenStates;
|
||||
|
||||
@ -135,6 +139,9 @@ public:
|
||||
void postContextStep(std::vector<RuntimeBuffers> const& contextBuffers, BufferManager& manager,
|
||||
GptModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
void postEachGenerationStep(BufferManager& manager, TensorPtr outputGenerationLogits, SizeType step,
|
||||
SizeType firstBatchSlotIdx, SizeType microBatchSize, SizeType beamWidth, WorldConfig const& worldConfig);
|
||||
|
||||
void prepareContextStep(TensorPtr const& inputIds, TokenIdType padId, BufferManager& manager,
|
||||
KvCacheManager const* kvCacheManager, SizeType firstBatchSlotIdx, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig);
|
||||
|
||||
@@ -1013,4 +1013,87 @@ void gatherLastTokenLogits(ITensor& output, ITensor const& input, ITensor const&
}
}

// In the following kernel, we launch a grid with microBatchSize * beamWidth blocks of threads. Each thread block
// copies a `vocabSizePadded` length logits tensor from the "inputLogits (microBatchSize, beamWidth, vocabSizePadded)"
// to the "outputGenerationLogits (batchSize, beamWidth, outPutLen, vocabSizePadded)"
template <typename T>
__global__ void copyLatestTokenLogitsInGenerationKernel(T* outputGenerationLogits, T const* inputLogits, int step,
int firstBatchSlotIdx, int beamWidth, int outPutLen, int vocabSizePadded)
{
// The relative batch slot index of this thread block within the micro batch.
int relativeBatchSlotIdx = blockIdx.x / beamWidth;

// The absolute batch slot index within the full batch.
int absoluteBatchSlotIdx = firstBatchSlotIdx + relativeBatchSlotIdx;

// The beam index that this thread block processes.
int mbeamIdx = blockIdx.x % beamWidth;

// The output pointer.
const unsigned int outputOffset
= (absoluteBatchSlotIdx * beamWidth * outPutLen + mbeamIdx * outPutLen + step) * vocabSizePadded;
T* outputPtr = &outputGenerationLogits[outputOffset];

// The input pointer.
const unsigned int inputOffset = (relativeBatchSlotIdx * beamWidth + mbeamIdx) * vocabSizePadded;
T const* inputPtr = &inputLogits[inputOffset];

// The threads in the block collaborate to copy the logits.
for (int idx = threadIdx.x; idx < vocabSizePadded; idx += blockDim.x)
{
outputPtr[idx] = inputPtr[idx];
}
}

template <typename T>
void invokeCopyLatestTokenLogitsInGeneration(ITensor& output, ITensor const& input, SizeType step,
SizeType firstBatchSlotIdx, SizeType microBatchSize, SizeType beamWidth, CudaStream const& stream)
{
auto const& outputShape = output.getShape();
auto const maxBatchSize = static_cast<std::uint32_t>(outputShape.d[0]);
auto const _beamWidth = static_cast<std::uint32_t>(outputShape.d[1]);
auto const outPutLen = static_cast<std::uint32_t>(outputShape.d[2]);
auto const vocabSizePadded = static_cast<std::uint32_t>(outputShape.d[3]);

TLLM_CHECK_WITH_INFO(maxBatchSize >= microBatchSize, "Invalid output shape: dim[0]");
TLLM_CHECK_WITH_INFO(_beamWidth == beamWidth, "Invalid output shape: dim[1]");
TLLM_CHECK_WITH_INFO(outPutLen >= step, "Invalid output shape: dim[2]");

auto const& inputShape = input.getShape();
TLLM_CHECK_WITH_INFO(inputShape.d[0] == microBatchSize, "Invalid input shape: dim[0]");
TLLM_CHECK_WITH_INFO(inputShape.d[1] == beamWidth, "Invalid input shape: dim[1]");
TLLM_CHECK_WITH_INFO(inputShape.d[2] == vocabSizePadded, "Invalid input shape: dim[2]");

dim3 const blockSize{256, 1};
dim3 const gridSize{static_cast<std::uint32_t>(microBatchSize * beamWidth), 1};

copyLatestTokenLogitsInGenerationKernel<<<gridSize, blockSize, 0, stream.get()>>>(
bufferCast<T>(output), bufferCast<T>(input), step, firstBatchSlotIdx, beamWidth, outPutLen, vocabSizePadded);
}

void copyLatestTokenLogitsInGeneration(ITensor& output, ITensor const& input, SizeType step, SizeType firstBatchSlotIdx,
SizeType microBatchSize, SizeType beamWidth, CudaStream const& stream)
{
switch (input.getDataType())
{
case nvinfer1::DataType::kFLOAT:
invokeCopyLatestTokenLogitsInGeneration<float>(
output, input, step, firstBatchSlotIdx, microBatchSize, beamWidth, stream);
break;
case nvinfer1::DataType::kHALF:
invokeCopyLatestTokenLogitsInGeneration<half>(
output, input, step, firstBatchSlotIdx, microBatchSize, beamWidth, stream);
break;
case nvinfer1::DataType::kBF16:
invokeCopyLatestTokenLogitsInGeneration<__nv_bfloat16>(
output, input, step, firstBatchSlotIdx, microBatchSize, beamWidth, stream);
break;
case nvinfer1::DataType::kFP8:
invokeCopyLatestTokenLogitsInGeneration<__nv_fp8_e4m3>(
output, input, step, firstBatchSlotIdx, microBatchSize, beamWidth, stream);
break;
default: TLLM_CHECK_WITH_INFO(false, "data type not supported");
}
}

} // namespace tensorrt_llm::runtime::kernels
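As the comment above the kernel notes, each thread block copies one `(batch, beam)` row of the current step's logits into the persistent `[batchSize, beamWidth, outPutLen, vocabSizePadded]` buffer. The host-side arithmetic below mirrors the offsets computed inside the kernel and can serve as a sanity check (a sketch with illustrative names, not part of the runtime API):

```cpp
#include <cstddef>

// Same flat offsets as copyLatestTokenLogitsInGenerationKernel above.
std::size_t generationLogitsOffset(std::size_t absoluteBatchSlotIdx, std::size_t beamIdx, std::size_t step,
    std::size_t beamWidth, std::size_t outPutLen, std::size_t vocabSizePadded)
{
    return (absoluteBatchSlotIdx * beamWidth * outPutLen + beamIdx * outPutLen + step) * vocabSizePadded;
}

std::size_t inputLogitsOffset(
    std::size_t relativeBatchSlotIdx, std::size_t beamIdx, std::size_t beamWidth, std::size_t vocabSizePadded)
{
    return (relativeBatchSlotIdx * beamWidth + beamIdx) * vocabSizePadded;
}
```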
@@ -81,4 +81,7 @@ void tileTensorInplace(ITensor& tensor, SizeType beamWidth, CudaStream const& st
void gatherLastTokenLogits(
ITensor& output, ITensor const& input, ITensor const& lastTokenIds, CudaStream const& stream);

void copyLatestTokenLogitsInGeneration(ITensor& output, ITensor const& input, SizeType step, SizeType firstBatchSlotIdx,
SizeType microBatchSize, SizeType beamWidth, CudaStream const& stream);

} // namespace tensorrt_llm::runtime::kernels
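The header above exposes the copy as an ordinary runtime kernel dispatched on the logits data type. A hedged call-site sketch, assuming a `BufferManager manager` built on a CUDA stream (as `GptDecoderBatch::postProcessRequest` does earlier in this diff) and fp16 logits; the shapes follow the checks in `invokeCopyLatestTokenLogitsInGeneration`, and the sizes are purely illustrative:

```cpp
// Sketch only: copy the first generation step's logits of a 2-request micro batch
// (beam width 1) into the persistent per-batch generation-logits buffer.
auto const vocabSizePadded = 32000;
auto output = manager.gpu(
    ITensor::makeShape({/*batchSize*/ 4, /*beamWidth*/ 1, /*maxNewTokens-1*/ 15, vocabSizePadded}),
    nvinfer1::DataType::kHALF);
auto input = manager.gpu(
    ITensor::makeShape({/*microBatchSize*/ 2, /*beamWidth*/ 1, vocabSizePadded}), nvinfer1::DataType::kHALF);
kernels::copyLatestTokenLogitsInGeneration(
    *output, *input, /*step*/ 0, /*firstBatchSlotIdx*/ 0, /*microBatchSize*/ 2, /*beamWidth*/ 1, manager.getStream());
```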
@ -120,7 +120,8 @@ void StatefulGptDecoder::reshapeBuffers(
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig const& samplingConfig)
|
||||
void StatefulGptDecoder::newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& manager = mBufferManager;
|
||||
@ -132,7 +133,7 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
|
||||
reshapeBuffers(batchSize, beamWidth, mMaxKvCacheLength, mMaxSequenceLength);
|
||||
mDecoder->setup(samplingConfig, batchSize);
|
||||
mDecoder->setup(samplingConfig, batchSize, mMaxSequenceLength);
|
||||
|
||||
// sanity checks, should always be true after reshape
|
||||
auto const& outputIdsShape = mDecodingOutput->ids->getShape();
|
||||
@ -189,6 +190,19 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
manager.setZero(*dOutput.finished);
|
||||
manager.setZero(*dOutput.finishedSum);
|
||||
|
||||
// If outputs contains cumLogProbs, use that
|
||||
if (outputs.cumLogProbs)
|
||||
{
|
||||
dOutput.cumLogProbs = outputs.cumLogProbs;
|
||||
}
|
||||
dOutput.logProbs = outputs.logProbs;
|
||||
|
||||
if (dOutput.cumLogProbs)
|
||||
manager.setZero(*dOutput.cumLogProbs);
|
||||
|
||||
if (dOutput.logProbs)
|
||||
manager.setZero(*dOutput.logProbs);
|
||||
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
std::vector<float> cumLogProbsHost(batchSize * beamWidth, DecodingOutput::kNegativeInfinity);
|
||||
@ -199,13 +213,6 @@ void StatefulGptDecoder::newBatch(GenerationInput const& inputs, SamplingConfig
|
||||
}
|
||||
manager.copy(cumLogProbsHost.data(), *dOutput.cumLogProbs);
|
||||
|
||||
// kernels::invokeFill(*dOutput.cumLogProbs, DecodingOutput::kNegativeInfinity, *stream);
|
||||
// for (SizeType batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
// {
|
||||
// auto cumLogProbsSlice = ITensor::slice(dOutput.cumLogProbs, batchIdx, 1);
|
||||
// manager.setZero(*IBuffer::slice(cumLogProbsSlice, 0, 1));
|
||||
// }
|
||||
|
||||
manager.setZero(*dOutput.parentIds);
|
||||
dOutput.beamHypotheses.init(manager, endId);
|
||||
}
|
||||
@ -268,14 +275,14 @@ void StatefulGptDecoder::forwardSync()
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
IStatefulGptDecoder::TensorPtr StatefulGptDecoder::getFinalOutputIds() const
|
||||
void StatefulGptDecoder::finalize() const
|
||||
{
|
||||
// TODO (rkobus) can we do this inplace?
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& outputIds = mDecodingOutput->ids;
|
||||
auto finalOutputIds = mBufferManager.gpu(outputIds->getShape(), outputIds->getDataType());
|
||||
IGptDecoder::gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager);
|
||||
mDecoder->gatherTree(*finalOutputIds, *mDecodingOutput, *mDecodingInput, mBufferManager);
|
||||
mBufferManager.copy(*finalOutputIds, *outputIds);
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return outputIds;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -43,14 +43,15 @@ public:
|
||||
nvinfer1::DataType dtype) override;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
void newBatch(GenerationInput const& input, SamplingConfig const& samplingConfig) override;
|
||||
void newBatch(
|
||||
GenerationInput const& input, GenerationOutput const& output, SamplingConfig const& samplingConfig) override;
|
||||
|
||||
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
|
||||
|
||||
void forwardSync() override;
|
||||
|
||||
//! @brief Gather final results for all requests.
|
||||
[[nodiscard]] TensorPtr getFinalOutputIds() const override;
|
||||
void finalize() const override;
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
|
||||
//! ids without padding, on gpu
|
||||
@ -59,6 +60,18 @@ public:
|
||||
return mDecodingOutput->ids;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getCumLogProbs() const override
|
||||
{
|
||||
return mDecodingOutput->cumLogProbs;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getLogProbs() const override
|
||||
{
|
||||
return mDecodingOutput->logProbs;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
[[nodiscard]] TensorPtr getNewTokens() const override
|
||||
{
|
||||
|
||||
@@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stringUtils.h"

#include <mpi.h>
@@ -24,30 +25,21 @@
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE

#define TLLM_MPI_CHECK(cmd, logger) \
#define TLLM_MPI_CHECK(cmd) \
do \
{ \
auto e = cmd; \
if (e != MPI_SUCCESS) \
{ \
(logger).log(nvinfer1::ILogger::Severity::kERROR, \
tensorrt_llm::common::fmtstr("Failed: MPI error %s:%d '%d'", __FILE__, __LINE__, e).c_str()); \
exit(EXIT_FAILURE); \
} \
TLLM_CHECK_WITH_INFO(e == MPI_SUCCESS, \
tensorrt_llm::common::fmtstr("Failed: MPI error %s:%d '%d'", __FILE__, __LINE__, e).c_str()); \
} while (0)

#if ENABLE_MULTI_DEVICE
#define TLLM_NCCL_CHECK(cmd, logger) \
#define TLLM_NCCL_CHECK(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) \
{ \
(logger).log(nvinfer1::ILogger::Severity::kERROR, \
tensorrt_llm::common::fmtstr( \
"Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)) \
.c_str()); \
exit(EXIT_FAILURE); \
} \
TLLM_CHECK_WITH_INFO(r == ncclSuccess, \
tensorrt_llm::common::fmtstr("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)) \
.c_str()); \
} while (0)
#endif // ENABLE_MULTI_DEVICE
@@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"

#include <csignal>
#include <cstdlib>
#include <mpi.h>

@@ -40,15 +41,18 @@ void initMpi(nvinfer1::ILogger& logger, int threadMode = MPI_THREAD_FUNNELED)
}

int initialized = 0;
TLLM_MPI_CHECK(MPI_Initialized(&initialized), logger);
TLLM_MPI_CHECK(MPI_Initialized(&initialized));
if (!initialized)
{
logger.log(
nvinfer1::ILogger::Severity::kINFO, tc::fmtstr("Initializing MPI with thread mode %d", threadMode).c_str());
int providedMode;
TLLM_MPI_CHECK(MPI_Init_thread(nullptr, nullptr, threadMode, &providedMode), logger);
TLLM_MPI_CHECK(MPI_Init_thread(nullptr, nullptr, threadMode, &providedMode));
TLLM_CHECK_WITH_INFO(providedMode >= threadMode, "MPI_Init_thread failed");
std::atexit([]() { MPI_Finalize(); });

auto previousHandler = std::signal(SIGABRT, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
}

mpiInitialized = true;
@@ -61,7 +65,7 @@ bool WorldConfig::validConfig(nvinfer1::ILogger& logger, SizeType tensorParallel
initMpi(logger);

int mpiSize;
TLLM_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpiSize), logger);
TLLM_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpiSize));
return mpiSize == tensorParallelism * pipelineParallelism;
}

@@ -71,8 +75,8 @@ WorldConfig WorldConfig::mpi(nvinfer1::ILogger& logger, SizeType gpusPerNode, st
initMpi(logger);

int mpiSize, mpiRank;
TLLM_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpiSize), logger);
TLLM_MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank), logger);
TLLM_MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpiSize));
TLLM_MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank));
logger.log(nvinfer1::ILogger::Severity::kINFO, tc::fmtstr("MPI size: %d, rank: %d", mpiSize, mpiRank).c_str());

auto pp = pipelineParallelism.value_or(1);

@@ -15,18 +15,16 @@
*/

#include "tensorrt_llm/thop/ncclCommunicatorOp.h"
#include "tensorrt_llm/runtime/tllmLogger.h"

namespace torch_ext
{

NcclCommunicatorOp::NcclCommunicatorOp(int64_t tpSize, int64_t ppSize, int64_t rank)
: mLogger(std::make_shared<tensorrt_llm::runtime::TllmLogger>())
, mRank(static_cast<int32_t>(rank))
: mRank(static_cast<int32_t>(rank))
{
tensorrt_llm::runtime::WorldConfig worldConfig{
static_cast<int32_t>(tpSize), static_cast<int32_t>(ppSize), static_cast<int32_t>(rank)};
mPipelineComm = tensorrt_llm::runtime::NcclCommunicator::createPipelineComm(worldConfig, *mLogger);
mPipelineComm = tensorrt_llm::runtime::NcclCommunicator::createPipelineComm(worldConfig);
}

void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
@@ -34,7 +32,7 @@ void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
auto ptr = reinterpret_cast<std::uint8_t*>(get_ptr<int8_t>(tensor));
size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
mPipelineComm->send(ptr, size, static_cast<int32_t>(toRank), cudaStream, *mLogger);
mPipelineComm->send(ptr, size, static_cast<int32_t>(toRank), cudaStream);
}

void NcclCommunicatorOp::recv(th::Tensor& tensor, int64_t fromRank) const
@@ -42,7 +40,7 @@ void NcclCommunicatorOp::recv(th::Tensor& tensor, int64_t fromRank) const
auto ptr = reinterpret_cast<std::uint8_t*>(get_ptr<int8_t>(tensor));
size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
mPipelineComm->receive(ptr, size, static_cast<int32_t>(fromRank), cudaStream, *mLogger);
mPipelineComm->receive(ptr, size, static_cast<int32_t>(fromRank), cudaStream);
}

} // namespace torch_ext

@@ -33,7 +33,6 @@ public:
void recv(th::Tensor& tensor, int64_t fromRank) const;

private:
std::shared_ptr<nvinfer1::ILogger> mLogger;
int32_t mRank;
std::shared_ptr<tensorrt_llm::runtime::NcclCommunicator> mPipelineComm;
};

@ -9,7 +9,6 @@
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
@ -64,11 +63,19 @@ struct BType<WeightOnlyQuantType::Int8b>
|
||||
struct CutlassKernel;
|
||||
struct CudaKernel;
|
||||
|
||||
void simple_assert(bool flag)
|
||||
{
|
||||
if (!flag)
|
||||
{
|
||||
throw std::runtime_error("assert failed");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantType BFlag>
|
||||
float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
|
||||
int k, int group_size, int warmup, int iter)
|
||||
{
|
||||
assert(zeros == nullptr && bias == nullptr && group_size == 0);
|
||||
simple_assert(zeros == nullptr && bias == nullptr && group_size == 0);
|
||||
cudaStream_t s;
|
||||
cudaStreamCreate(&s);
|
||||
cudaEvent_t begin, end;
|
||||
@ -115,7 +122,6 @@ float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, v
|
||||
cudaEventSynchronize(end);
|
||||
float time;
|
||||
cudaEventElapsedTime(&time, begin, end);
|
||||
fast_time = std::min(fast_time, time);
|
||||
if (time < fast_time)
|
||||
{
|
||||
fast_time = time;
|
||||
@ -127,7 +133,6 @@ float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, v
|
||||
{
|
||||
gemm.gemm(act, weight, scales, out, m, n, k, best_config, ws_ptr, ws_bytes, s);
|
||||
}
|
||||
cudaProfilerStart();
|
||||
cudaEventRecord(begin, s);
|
||||
for (int i = 0; i < iter; ++i)
|
||||
{
|
||||
@ -151,7 +156,7 @@ template <typename KernelFlag, WeightOnlyActivationType AFlag, WeightOnlyQuantTy
|
||||
float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, void* bias, void* out, int m, int n,
|
||||
int k, int group_size, int warmup, int iter)
|
||||
{
|
||||
assert(zeros && bias && (group_size == 64 || group_size == 128));
|
||||
simple_assert(zeros && bias && (group_size == 64 || group_size == 128));
|
||||
cudaStream_t s;
|
||||
cudaStreamCreate(&s);
|
||||
cudaEvent_t begin, end;
|
||||
@ -198,7 +203,6 @@ float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, vo
|
||||
cudaEventSynchronize(end);
|
||||
float time;
|
||||
cudaEventElapsedTime(&time, begin, end);
|
||||
fast_time = std::min(fast_time, time);
|
||||
if (time < fast_time)
|
||||
{
|
||||
fast_time = time;
|
||||
@ -210,7 +214,6 @@ float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, vo
|
||||
{
|
||||
gemm.gemm(act, weight, scales, zeros, bias, out, m, n, k, group_size, best_config, ws_ptr, ws_bytes, s);
|
||||
}
|
||||
cudaProfilerStart();
|
||||
cudaEventRecord(begin, s);
|
||||
for (int i = 0; i < iter; ++i)
|
||||
{
|
||||
@ -379,11 +382,20 @@ bool benchmark(int m, int n, int k, int group_size, int warmup, int iter)
|
||||
}
|
||||
|
||||
float time1, time2;
|
||||
time1 = benchmark_perchannel<CudaKernel, AFlag, BFlag>(
|
||||
d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter);
|
||||
std::function<decltype(benchmark_perchannel<CudaKernel, AFlag, BFlag>)> benchmark_func_cuda
|
||||
= benchmark_perchannel<CudaKernel, AFlag, BFlag>;
|
||||
std::function<decltype(benchmark_perchannel<CutlassKernel, AFlag, BFlag>)> benchmark_func_cutlass
|
||||
= benchmark_perchannel<CutlassKernel, AFlag, BFlag>;
|
||||
if (group_size != 0)
|
||||
{
|
||||
benchmark_func_cuda = benchmark_groupwise<CudaKernel, AFlag, BFlag>;
|
||||
benchmark_func_cutlass = benchmark_groupwise<CutlassKernel, AFlag, BFlag>;
|
||||
}
|
||||
time1 = benchmark_func_cuda(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k,
|
||||
group_size, warmup, iter);
|
||||
d_out.copy_to(h_out1.data());
|
||||
time2 = benchmark_perchannel<CutlassKernel, AFlag, BFlag>(
|
||||
d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n, k, 0, warmup, iter);
|
||||
time2 = benchmark_func_cutlass(d_act.data(), d_weight.data(), d_scales.data(), p_zeros, p_bias, d_out.data(), m, n,
|
||||
k, group_size, warmup, iter);
|
||||
d_out.copy_to(h_out2.data());
|
||||
float quant_scale = 1.f / (1 << (8 / elem_per_byte - 1));
|
||||
bool pass = compare<AT>(h_out1.data(), h_out2.data(), m * n, quant_scale);
|
||||
|
||||
@ -20,7 +20,6 @@ import shutil as _shutil
|
||||
import subprocess as _sp
|
||||
import sys
|
||||
import typing as _tp
|
||||
from collections import OrderedDict as _OrderedDict
|
||||
from pathlib import Path as _Path
|
||||
|
||||
import torch.multiprocessing as _mp
|
||||
@ -35,11 +34,11 @@ engine_target_path = _pl.Path(
|
||||
import build as _ecb
|
||||
|
||||
|
||||
def build_engine(model_version: str, weight_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
def build_engine(model_name: str, weight_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
world_size, *args):
|
||||
args = [
|
||||
'-m',
|
||||
str(model_version),
|
||||
str(model_name),
|
||||
'--log_level=error',
|
||||
'--model_dir',
|
||||
str(weight_dir),
|
||||
@ -47,6 +46,8 @@ def build_engine(model_version: str, weight_dir: _pl.Path, engine_dir: _pl.Path,
|
||||
str(engine_dir),
|
||||
'--max_batch_size=2',
|
||||
'--max_beam_width=2',
|
||||
"--max_input_len=512",
|
||||
"--max_output_len=512",
|
||||
'--builder_opt=0',
|
||||
f'--world_size={world_size}',
|
||||
] + list(args)
|
||||
@ -64,14 +65,8 @@ def run_command(command: _tp.Sequence[str], *, cwd=None, **kwargs) -> None:
|
||||
|
||||
def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
|
||||
model_name_dict = _OrderedDict([
|
||||
["chatglm-6b", "1"],
|
||||
["chatglm2-6b", "2"],
|
||||
["chatglm3-6b", "3"],
|
||||
])
|
||||
hf_dir_list = [
|
||||
resources_dir / model_name for model_name in model_name_dict.keys()
|
||||
]
|
||||
model_name_list = ["chatglm_6b", "chatglm2_6b", "chatglm3_6b"]
|
||||
hf_dir_list = [resources_dir / model_name for model_name in model_name_list]
|
||||
trt_dir = resources_dir / "trtModel"
|
||||
|
||||
run_command(
|
||||
@ -80,21 +75,23 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
cwd=resources_dir)
|
||||
|
||||
# Clone the model directory
|
||||
for model_name, hf_dir in zip(model_name_dict.keys(), hf_dir_list):
|
||||
for model_name, hf_dir in zip(model_name_list, hf_dir_list):
|
||||
if not _Path(hf_dir).exists():
|
||||
run_command(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"https://huggingface.co/THUDM/" + model_name,
|
||||
"https://huggingface.co/THUDM/" +
|
||||
model_name.replace("_", "-"),
|
||||
model_name,
|
||||
],
|
||||
cwd=resources_dir,
|
||||
)
|
||||
|
||||
print("\nBuilding engines")
|
||||
for model, hf_dir in zip(model_name_dict.items(), hf_dir_list):
|
||||
print("Building %s" % model[0])
|
||||
build_engine(model[1], hf_dir, trt_dir, world_size)
|
||||
for model_name, hf_dir in zip(model_name_list, hf_dir_list):
|
||||
print("Building %s" % model_name)
|
||||
build_engine(model_name, hf_dir, trt_dir, world_size)
|
||||
|
||||
if not _Path(engine_target_path).exists():
|
||||
_Path(engine_target_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@ -39,21 +38,14 @@ from build import find_engines # isort:skip
|
||||
|
||||
def generate(model_name, batch_size, beam_width):
|
||||
|
||||
model_name_dict = OrderedDict([
|
||||
["chatglm-6b", "1"],
|
||||
["chatglm2-6b", "2"],
|
||||
["chatglm3-6b", "3"],
|
||||
])
|
||||
|
||||
print("generate expected %s output BatchSize=%d, BeamWidth=%d" %
|
||||
(model_name, batch_size, beam_width))
|
||||
|
||||
args = parse_arguments()
|
||||
args = parse_arguments(['-m', model_name])
|
||||
if batch_size == 1:
|
||||
args.input_text = args.input_text[:1]
|
||||
elif batch_size > 2:
|
||||
args.input_text += args.input_text[0] * (batch_size - 2)
|
||||
args.model_version = model_name_dict[model_name]
|
||||
args.beam_width = beam_width
|
||||
args.tokenizer_dir = resources_dir / model_name
|
||||
args.engine_dir = Path(__file__).parent.parent / "models/rt_engine/chatglm"
|
||||
@ -65,17 +57,22 @@ def generate(model_name, batch_size, beam_width):
|
||||
config = json.load(f)
|
||||
assert (config['builder_config']['name'] == model_name)
|
||||
dtype = config['builder_config']['precision']
|
||||
end_id = config['builder_config']['eos_token_id']
|
||||
pad_id = config['builder_config']['pad_token_id']
|
||||
config['builder_config']['max_batch_size']
|
||||
max_input_len = config['builder_config']['max_input_len']
|
||||
max_output_len = config['builder_config']['max_output_len']
|
||||
config['builder_config']['max_beam_width']
|
||||
remove_input_padding = config['builder_config']['remove_input_padding']
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == tensorrt_llm.mpi_world_size(
|
||||
), f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
runtime_mapping = tensorrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
runtime_mapping = tensorrt_llm.Mapping(
|
||||
world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size,
|
||||
)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
serialize_path = find_engines(
|
||||
@ -88,15 +85,51 @@ def generate(model_name, batch_size, beam_width):
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir, trust_remote_code=True)
|
||||
end_id = tokenizer.eos_token_id
|
||||
pad_id = tokenizer.pad_token_id
|
||||
if args.model_name in ["glm_10b"]:
|
||||
sop_id = tokenizer.sop_token_id
|
||||
eop_id = tokenizer.eop_token_id
|
||||
input_text = args.input_text
|
||||
tokenized = tokenizer(input_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_length=True)
|
||||
input_ids = tokenized['input_ids'].int().contiguous().cuda()
|
||||
input_lengths = tokenized['length'].int().contiguous().cuda()
|
||||
input_ids = tokenized['input_ids'].int()
|
||||
input_lengths = tokenized['length'].int()
|
||||
max_input_len_real = torch.max(input_lengths)
|
||||
if max_input_len_real > max_input_len:
|
||||
print("Truncate input_length as %d" % max_input_len)
|
||||
input_ids = input_ids[:, :max_input_len]
|
||||
input_lengths = torch.where(input_lengths > max_input_len,
|
||||
max_input_len, input_lengths)
|
||||
else:
|
||||
max_input_len = max_input_len_real
|
||||
if args.model_name in ["glm_10b"]:
|
||||
input_ids = torch.cat(
|
||||
(input_ids, input_ids.new_full((batch_size, 1), sop_id)),
|
||||
dim=-1,
|
||||
)
|
||||
input_lengths += 1
|
||||
max_input_len_real += 1
|
||||
|
||||
if use_gpt_attention_plugin:
|
||||
if remove_input_padding:
|
||||
input_ids_no_padding = torch.zeros(1,
|
||||
torch.sum(input_lengths),
|
||||
dtype=torch.int32)
|
||||
lengths_acc = torch.cumsum(
|
||||
torch.cat([torch.IntTensor([0]), input_lengths]),
|
||||
dim=0,
|
||||
)
|
||||
for i in range(len(input_ids)):
|
||||
input_ids_no_padding[
|
||||
0, lengths_acc[i]:lengths_acc[i + 1]] = torch.IntTensor(
|
||||
input_ids[i,
|
||||
max_input_len - input_lengths[i]:max_input_len])
|
||||
|
||||
input_ids = input_ids_no_padding
|
||||
|
||||
elif use_gpt_attention_plugin:
|
||||
# when using gpt attention plugin, inputs needs to align at the head
|
||||
input_ids_padding_right = torch.zeros_like(input_ids) + end_id
|
||||
for i, sample in enumerate(input_ids):
|
||||
@ -125,7 +158,7 @@ def generate(model_name, batch_size, beam_width):
|
||||
)
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
end_id=end_id,
|
||||
end_id=eop_id if args.model_name in ["glm_10b"] else end_id,
|
||||
pad_id=pad_id,
|
||||
num_beams=args.beam_width,
|
||||
temperature=args.temperature,
|
||||
@ -136,23 +169,35 @@ def generate(model_name, batch_size, beam_width):
|
||||
|
||||
with open(serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
if model_name == 'chatglm-6b':
|
||||
decoder = ChatGLMGenerationSession(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping,
|
||||
)
|
||||
|
||||
if args.model_name in ["chatglm_6b", "glm_10b"]:
|
||||
session = ChatGLMGenerationSession
|
||||
else:
|
||||
decoder = GenerationSession(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping,
|
||||
)
|
||||
decoder.setup(input_ids.size(0), input_ids.size(1), args.max_output_len,
|
||||
args.beam_width)
|
||||
output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
|
||||
session = GenerationSession
|
||||
decoder = session(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping,
|
||||
)
|
||||
|
||||
decoder.setup(
|
||||
len(input_text),
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
beam_width,
|
||||
)
|
||||
output = decoder.decode(
|
||||
input_ids.contiguous().cuda(),
|
||||
input_lengths.contiguous().cuda(),
|
||||
sampling_config,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
output_ids = output["output_ids"]
|
||||
output["sequence_lengths"]
|
||||
|
||||
data_path = Path(__file__).parent.parent / "data" / model_name
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
nBS, nBM = input_ids.size(0), args.beam_width
|
||||
@ -174,12 +219,13 @@ def generate(model_name, batch_size, beam_width):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generate("chatglm-6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm-6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm2-6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm2-6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm2-6b", batch_size=1, beam_width=2)
|
||||
generate("chatglm3-6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm3-6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm3-6b", batch_size=1, beam_width=2)
|
||||
generate("chatglm_6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm2_6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm2_6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm2_6b", batch_size=1, beam_width=2)
|
||||
generate("chatglm3_6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm3_6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm3_6b", batch_size=1, beam_width=2)
|
||||
#generate("glm_10b", batch_size=1, beam_width=1)
|
||||
#generate("glm_10b", batch_size=2, beam_width=1)
|
||||
print("Done.")
|
||||
|
||||
@ -86,12 +86,14 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
|
||||
build_dir: _tp.Optional[str] = None,
|
||||
dist_dir: _tp.Optional[str] = None,
|
||||
model_cache: _tp.Optional[str] = None,
|
||||
skip_gpt=False,
|
||||
skip_gptj=False,
|
||||
skip_llama=False,
|
||||
skip_chatglm=False,
|
||||
only_fp8=False,
|
||||
only_multi_gpu=False,
|
||||
trt_root: _tp.Optional[str] = None) -> None:
|
||||
trt_root: _tp.Optional[str] = None,
|
||||
build_only=False) -> None:
|
||||
root_dir = find_root_dir()
|
||||
_log.info("Using root directory: %s", str(root_dir))
|
||||
|
||||
@ -114,27 +116,39 @@ def run_tests(cuda_architectures: _tp.Optional[str] = None,
|
||||
root_dir=root_dir,
|
||||
resources_dir=resources_dir,
|
||||
model_cache=model_cache,
|
||||
skip_gpt=skip_gpt,
|
||||
skip_gptj=skip_gptj,
|
||||
skip_llama=skip_llama,
|
||||
skip_chatglm=skip_chatglm,
|
||||
only_fp8=only_fp8)
|
||||
|
||||
if build_only:
|
||||
return
|
||||
|
||||
run_google_tests(build_dir=build_dir,
|
||||
skip_gpt=skip_gpt,
|
||||
skip_gptj=skip_gptj,
|
||||
skip_llama=skip_llama,
|
||||
skip_chatglm=skip_chatglm,
|
||||
only_fp8=only_fp8)
|
||||
|
||||
run_benchmarks(python_exe=python_exe,
|
||||
root_dir=root_dir,
|
||||
build_dir=build_dir,
|
||||
resources_dir=resources_dir)
|
||||
if not skip_gpt:
|
||||
run_benchmarks(python_exe=python_exe,
|
||||
root_dir=root_dir,
|
||||
build_dir=build_dir,
|
||||
resources_dir=resources_dir)
|
||||
else:
|
||||
_log.info("Skipping benchmarks")
|
||||
|
||||
else:
|
||||
prepare_multi_gpu_model_tests(python_exe=python_exe,
|
||||
root_dir=root_dir,
|
||||
resources_dir=resources_dir,
|
||||
model_cache=model_cache)
|
||||
|
||||
if build_only:
|
||||
return
|
||||
|
||||
run_multi_gpu_tests(build_dir=build_dir)
|
||||
|
||||
|
||||
@ -142,6 +156,7 @@ def prepare_all_model_tests(python_exe: str,
|
||||
root_dir: _pl.Path,
|
||||
resources_dir: _pl.Path,
|
||||
model_cache: _tp.Optional[str] = None,
|
||||
skip_gpt=False,
|
||||
skip_gptj=False,
|
||||
skip_llama=False,
|
||||
skip_chatglm=False,
|
||||
@ -149,11 +164,14 @@ def prepare_all_model_tests(python_exe: str,
|
||||
model_cache_arg = ["--model_cache", model_cache] if model_cache else []
|
||||
only_fp8_arg = ["--only_fp8"] if only_fp8 else []
|
||||
|
||||
prepare_model_tests(model_name="gpt",
|
||||
python_exe=python_exe,
|
||||
root_dir=root_dir,
|
||||
resources_dir=resources_dir,
|
||||
model_cache_arg=model_cache_arg)
|
||||
if not skip_gpt:
|
||||
prepare_model_tests(model_name="gpt",
|
||||
python_exe=python_exe,
|
||||
root_dir=root_dir,
|
||||
resources_dir=resources_dir,
|
||||
model_cache_arg=model_cache_arg)
|
||||
else:
|
||||
_log.info("Skipping GPT tests")
|
||||
|
||||
if not skip_gptj:
|
||||
prepare_model_tests(model_name="gptj",
|
||||
@ -228,8 +246,8 @@ def prepare_model_tests(model_name: str,
|
||||
run_command(generate_expected_output, cwd=root_dir, env=model_env)
|
||||
|
||||
|
||||
def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm,
|
||||
only_fp8):
|
||||
def run_google_tests(build_dir: _pl.Path, skip_gpt, skip_gptj, skip_llama,
|
||||
skip_chatglm, only_fp8):
|
||||
make_google_tests = [
|
||||
"cmake", "--build", ".", "--config", "Release", "-j", "--target",
|
||||
"google-tests"
|
||||
@ -239,6 +257,10 @@ def run_google_tests(build_dir: _pl.Path, skip_gptj, skip_llama, skip_chatglm,
|
||||
cpp_env = {**_os.environ}
|
||||
ctest = ["ctest", "--output-on-failure", "--output-junit", "results.xml"]
|
||||
excluded_tests = []
|
||||
if skip_gpt:
|
||||
excluded_tests.append(
|
||||
".*GptTest.*|.*GptSessionTest.*|.*GptManagerTest.*|.*TrtGptModelTest.*"
|
||||
)
|
||||
if skip_gptj:
|
||||
excluded_tests.append(".*Gptj.*")
|
||||
if skip_llama:
|
||||
@ -343,6 +365,21 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model_cache",
|
||||
type=str,
|
||||
help="Directory where models are stored")
|
||||
parser.add_argument("--only_gpt",
|
||||
action="store_true",
|
||||
help="Run only the tests for GPT")
|
||||
parser.add_argument("--only_gptj",
|
||||
action="store_true",
|
||||
help="Run only the tests for GPT-J")
|
||||
parser.add_argument("--only_llama",
|
||||
action="store_true",
|
||||
help="Run only the tests for Llama")
|
||||
parser.add_argument("--only_chatglm",
|
||||
action="store_true",
|
||||
help="Run only the tests for ChatGLM")
|
||||
parser.add_argument("--skip_gpt",
|
||||
action="store_true",
|
||||
help="Skip the tests for GPT")
|
||||
parser.add_argument("--skip_gptj",
|
||||
action="store_true",
|
||||
help="Skip the tests for GPT-J")
|
||||
@ -360,5 +397,39 @@ if __name__ == "__main__":
|
||||
"--only_multi_gpu",
|
||||
action="store_true",
|
||||
help="Run only mulit-GPU tests. Implemented for 4 GPUs.")
|
||||
parser.add_argument("--build_only",
|
||||
action="store_true",
|
||||
help="Build only, do not run tests.")
|
||||
|
||||
run_tests(**vars(parser.parse_args()))
|
||||
args = parser.parse_args()
|
||||
|
||||
if (args.only_gpt + args.only_gptj + args.only_llama + args.only_chatglm >
|
||||
1):
|
||||
parser.error('Cannot combine multiple only_* arguments.')
|
||||
|
||||
if args.only_gpt:
|
||||
args.skip_gptj = True
|
||||
args.skip_llama = True
|
||||
args.skip_chatglm = True
|
||||
|
||||
if args.only_gptj:
|
||||
args.skip_gpt = True
|
||||
args.skip_llama = True
|
||||
args.skip_chatglm = True
|
||||
|
||||
if args.only_llama:
|
||||
args.skip_gpt = True
|
||||
args.skip_gptj = True
|
||||
args.skip_chatglm = True
|
||||
|
||||
if args.only_chatglm:
|
||||
args.skip_gpt = True
|
||||
args.skip_gptj = True
|
||||
args.skip_llama = True
|
||||
|
||||
del args.only_gpt
|
||||
del args.only_gptj
|
||||
del args.only_llama
|
||||
del args.only_chatglm
|
||||
|
||||
run_tests(**vars(args))
|
||||
|
||||
@ -115,7 +115,7 @@ TEST_F(BufferManagerTest, Pointers)
|
||||
static_assert(static_cast<nvinfer1::DataType>(trtPointerType) == BufferDataType::kTrtPointerType);
|
||||
static_assert(trtPointerType == BufferDataType::kTrtPointerType); // uses implicit type conversion
|
||||
// The C++ type corresponding to the TensorRT type for storing pointers (int64_t)
|
||||
using cppStorageType = CppDataType<trtPointerType>::type;
|
||||
using cppStorageType = DataTypeTraits<trtPointerType>::type;
|
||||
static_assert(sizeof(cppStorageType) == sizeof(cppPointerType));
|
||||
|
||||
BufferManager manager(mStream);
|
||||
@ -152,4 +152,21 @@ TEST_F(BufferManagerTest, MemPoolAttributes)
|
||||
std::uint64_t threshold{0};
|
||||
TLLM_CUDA_CHECK(cudaMemPoolGetAttribute(memPool, cudaMemPoolAttrReleaseThreshold, &threshold));
|
||||
EXPECT_EQ(threshold, std::numeric_limits<std::uint64_t>::max());
|
||||
|
||||
manager.memoryPoolTrimTo(0);
|
||||
auto const reserved = manager.memoryPoolReserved();
|
||||
auto const used = manager.memoryPoolUsed();
|
||||
auto const free = manager.memoryPoolFree();
|
||||
EXPECT_EQ(free, reserved - used);
|
||||
auto constexpr kBytesToReserve = 1 << 20;
|
||||
{
|
||||
auto const mem = manager.allocate(MemoryType::kGPU, kBytesToReserve);
|
||||
EXPECT_EQ(mem->getSize(), kBytesToReserve);
|
||||
EXPECT_GE(manager.memoryPoolReserved(), reserved + kBytesToReserve);
|
||||
EXPECT_GE(manager.memoryPoolUsed(), used + kBytesToReserve);
|
||||
}
|
||||
EXPECT_GE(manager.memoryPoolFree(), free + kBytesToReserve);
|
||||
manager.memoryPoolTrimTo(0);
|
||||
EXPECT_LE(manager.memoryPoolReserved(), reserved);
|
||||
EXPECT_LE(manager.memoryPoolFree(), free);
|
||||
}
|
||||
|
||||
@ -90,7 +90,8 @@ void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
|
||||
}
|
||||
}
|
||||
|
||||
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs, int maxBeamWidth)
|
||||
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs, int maxBeamWidth,
|
||||
bool computeLogProbs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
SizeType constexpr tensorParallelism{1};
|
||||
@ -172,7 +173,11 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> con
|
||||
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
|
||||
kernels::invokeFill(*input, tokenId, *streamPtr);
|
||||
inputIds.emplace_back(input);
|
||||
decoder.newRequest(b, decoder_batch::Request{inputIds[b], maxNewTokens, endId, padId}, samplingConfigs[b]);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
@ -206,13 +211,16 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> con
|
||||
EXPECT_NO_THROW(decoder.forward(outputs, inputs));
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
|
||||
|
||||
decoder.newRequest(0, decoder_batch::Request{inputIds[0], maxNewTokens}, samplingConfigs[0]);
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[0], maxNewTokens};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(0, decoderRequest, samplingConfigs[0]);
|
||||
EXPECT_FALSE(decoder.getFinished()[0]);
|
||||
EXPECT_EQ(decoder.getNbSteps()[0], 0);
|
||||
}
|
||||
|
||||
void testDecoderWavefront(
|
||||
nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs, int maxBeamWidth)
|
||||
void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
|
||||
int maxBeamWidth, bool computeLogProbs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
SizeType constexpr tensorParallelism{1};
|
||||
@ -302,7 +310,11 @@ void testDecoderWavefront(
|
||||
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
|
||||
kernels::invokeFill(*input, tokenId, *streamPtr);
|
||||
inputIds.emplace_back(input);
|
||||
decoder.newRequest(b, decoder_batch::Request{inputIds[b], maxNewTokens, endId, padId}, samplingConfigs[b]);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
|
||||
|
||||
decoder.forward(outputs, inputs);
|
||||
|
||||
@ -335,7 +347,7 @@ struct BeamConfig
|
||||
std::vector<SizeType> beamWidths;
|
||||
};
|
||||
|
||||
class ParamTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig>>
|
||||
class ParamTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
|
||||
{
|
||||
};
|
||||
|
||||
@ -343,32 +355,39 @@ TEST_P(ParamTest, Test)
|
||||
{
|
||||
nvinfer1::DataType const dtype{std::get<0>(GetParam())};
|
||||
BeamConfig const beamConfig{std::get<1>(GetParam())};
|
||||
bool const computeLogProbs{std::get<2>(GetParam())};
|
||||
std::vector<SamplingConfig> samplingConfigs;
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
samplingConfigs.emplace_back(beamWidth);
|
||||
}
|
||||
|
||||
testDecoder(dtype, samplingConfigs, beamConfig.maxBeamWidth);
|
||||
testDecoder(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderTest, ParamTest,
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamTest,
|
||||
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
|
||||
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}})),
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
|
||||
testing::Values(false, true)),
|
||||
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
|
||||
{
|
||||
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
|
||||
BeamConfig const beamConfig = std::get<1>(info.param);
|
||||
bool const computeLogProbs = std::get<2>(info.param);
|
||||
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
|
||||
for (auto const beamWdith : beamConfig.beamWidths)
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
name.append("Bw" + std::to_string(beamWdith));
|
||||
name.append("Bw" + std::to_string(beamWidth));
|
||||
}
|
||||
if (computeLogProbs)
|
||||
{
|
||||
name.append("LogProbs");
|
||||
}
|
||||
return name;
|
||||
});
|
||||
|
||||
class ParamWavefrontTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig>>
|
||||
class ParamWavefrontTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
|
||||
{
|
||||
};
|
||||
|
||||
@ -376,27 +395,34 @@ TEST_P(ParamWavefrontTest, Test)
|
||||
{
|
||||
nvinfer1::DataType const dtype{std::get<0>(GetParam())};
|
||||
BeamConfig const beamConfig{std::get<1>(GetParam())};
|
||||
bool const computeLogProbs{std::get<2>(GetParam())};
|
||||
std::vector<SamplingConfig> samplingConfigs;
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
samplingConfigs.emplace_back(beamWidth);
|
||||
}
|
||||
|
||||
testDecoderWavefront(dtype, samplingConfigs, beamConfig.maxBeamWidth);
|
||||
testDecoderWavefront(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderTest, ParamWavefrontTest,
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamWavefrontTest,
|
||||
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
|
||||
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}})),
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
|
||||
testing::Values(false, true)),
|
||||
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
|
||||
{
|
||||
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
|
||||
BeamConfig const beamConfig = std::get<1>(info.param);
|
||||
bool const computeLogProbs = std::get<2>(info.param);
|
||||
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
|
||||
for (auto const beamWdith : beamConfig.beamWidths)
|
||||
{
|
||||
name.append("Bw" + std::to_string(beamWdith));
|
||||
}
|
||||
if (computeLogProbs)
|
||||
{
|
||||
name.append("LogProbs");
|
||||
}
|
||||
return name;
|
||||
});
|
||||
|
||||
@ -55,18 +55,17 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC
|
||||
auto const beamWidth = samplingConfig.beamWidth;
|
||||
SizeType constexpr batchSize{4};
|
||||
|
||||
decoder->setup(samplingConfig, batchSize);
|
||||
|
||||
int constexpr endId{50257};
|
||||
SizeType constexpr maxInputLength{8};
|
||||
SizeType constexpr maxNewTokens{2};
|
||||
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
|
||||
decoder->setup(samplingConfig, batchSize, maxSeqLength);
|
||||
|
||||
// set up inputs
|
||||
auto logits = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, beamWidth, vocabSizePadded}), modelConfig.getDataType()));
|
||||
manager.setZero(*logits);
|
||||
|
||||
int constexpr endId{50257};
|
||||
std::vector<int> const endIdsVec(batchSize * beamWidth, endId);
|
||||
auto endIds
|
||||
= std::shared_ptr(manager.copyFrom(endIdsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU));
|
||||
|
||||
@ -40,7 +40,7 @@ namespace fs = std::filesystem;
|
||||
namespace
|
||||
{
|
||||
auto const TEST_RESOURCE_PATH = fs::path{TOP_LEVEL_DIR} / "cpp/tests/resources";
|
||||
auto const ENGINGE_PATH = TEST_RESOURCE_PATH / "models/rt_engine";
|
||||
auto const ENGINE_PATH = TEST_RESOURCE_PATH / "models/rt_engine";
|
||||
auto const DATA_PATH = TEST_RESOURCE_PATH / "data";
|
||||
|
||||
auto const GPT_MODEL_DIR = "gpt2";
|
||||
@ -500,7 +500,7 @@ TEST_P(ParamTest, Test)
|
||||
|
||||
std::ostringstream gpuSizePath;
|
||||
gpuSizePath << "tp" << modelSpec.mTPSize << "-pp" << modelSpec.mPPSize << "-gpu";
|
||||
auto const modelPath{ENGINGE_PATH / modelDir / modelSpec.mModelPath / gpuSizePath.str()};
|
||||
auto const modelPath{ENGINE_PATH / modelDir / modelSpec.mModelPath / gpuSizePath.str()};
|
||||
auto const resultsPath
|
||||
= DATA_PATH / modelDir / ((beamWidth == 1) ? "sampling" : "beam_search_" + std::to_string(beamWidth));
|
||||
fs::path const resultsFile{resultsPath / modelSpec.mResultsFile};
|
||||
@ -642,7 +642,7 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16WithAttentionPlugin)
|
||||
GTEST_SKIP() << "Run only on demand";
|
||||
auto const modelDir = "llama_7bf";
|
||||
auto const engineDir = "llama_7bf_outputs_tp1";
|
||||
auto const modelPath{ENGINGE_PATH / modelDir / engineDir};
|
||||
auto const modelPath{ENGINE_PATH / modelDir / engineDir};
|
||||
SizeType constexpr beamWidth{1};
|
||||
fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE};
|
||||
auto const batchSizes = {8};
|
||||
@ -659,7 +659,7 @@ TEST_F(LlamaSessionOnDemandTest, SamplingFP16AttentionPluginDecoderBatch)
|
||||
{
|
||||
GTEST_SKIP() << "Run only on demand";
|
||||
auto const modelDir = "llamav2";
|
||||
auto const modelPath{ENGINGE_PATH / modelDir};
|
||||
auto const modelPath{ENGINE_PATH / modelDir};
|
||||
SizeType constexpr beamWidth{1};
|
||||
fs::path resultsFile{DATA_PATH / modelDir / FP16_RESULT_FILE};
|
||||
auto const batchSizes = {8};
|
||||
@ -676,11 +676,11 @@ class ChatGlmSessionTest : public SessionTest // for ChatGLM-6B
|
||||
{
|
||||
};
|
||||
|
||||
class ChatGlm2SessionTest : public SessionTest // for ChatGLM2-6B and ChatGLM2-6B-32k
|
||||
class ChatGlm2SessionTest : public SessionTest // for ChatGLM2-6B
|
||||
{
|
||||
};
|
||||
|
||||
class ChatGlm3SessionTest : public SessionTest // for ChatGLM3-6B and ChatGLM3-6B-32k
|
||||
class ChatGlm3SessionTest : public SessionTest // for ChatGLM3-6B
|
||||
{
|
||||
};
|
||||
|
||||
@ -691,7 +691,7 @@ namespace
|
||||
{
|
||||
|
||||
// TODO: consolidate this function with testGptSession
|
||||
// Notice: all ChatGLM models (ChatGLM-6B, ChatGLM2-6B, ChatGLM3-6B, ChatGLM2-6B-32k and ChatGLM3-6B-32k) use this
|
||||
// Notice: all ChatGLM / GLM models use this
|
||||
// function The differences are GptModelConfig::ModelVariant
|
||||
void testChatGlmSession(fs::path const& modelPath, std::string const& modelName, ModelSpec const& modelSpec,
|
||||
ModelIds const modelIds, SizeType beamWidth, std::initializer_list<int> const& batchSizes,
|
||||
@ -704,7 +704,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
|
||||
std::string fileNameSuffix
|
||||
= std::string("-BS") + std::to_string(batchSize) + "-BM" + std::to_string(beamWidth) + std::string(".npy");
|
||||
fs::path givenInputPath = DATA_PATH / modelName / (std::string("inputId") + fileNameSuffix);
|
||||
auto const& givenInput = utils::loadNpy(manager, givenInputPath, MemoryType::kCPU);
|
||||
auto const& givenInput = utils::loadNpy(manager, givenInputPath.string(), MemoryType::kCPU);
|
||||
auto const& inputShape = givenInput->getShape();
|
||||
ASSERT_EQ(inputShape.nbDims, 2);
|
||||
ASSERT_GT(inputShape.d[0], 0);
|
||||
@ -730,7 +730,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
|
||||
ASSERT_TRUE(fs::exists(enginePath));
|
||||
|
||||
auto const maxInputLength = static_cast<SizeType>(inputShape.d[1]);
|
||||
auto const maxNewTokens = 1024;
|
||||
auto const maxNewTokens = 512;
|
||||
auto const maxSeqLengthGroundTruth = static_cast<SizeType>(outputShape.d[2]);
|
||||
auto const maxSeqLength = maxInputLength + maxNewTokens;
|
||||
SamplingConfig samplingConfig{beamWidth};
|
||||
@ -866,8 +866,8 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
|
||||
|
||||
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const modelName{"chatglm_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -876,22 +876,10 @@ TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
|
||||
}
|
||||
|
||||
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
|
||||
{
|
||||
auto const modelName{"chatglm-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const batchSizes = {2};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
auto const modeIds = ModelIds{130005, 130005};
|
||||
|
||||
testChatGlmSession(modelPath, modelName, modelSpec, modeIds, 1, batchSizes, mLogger, false, MicroBatchSizes());
|
||||
}
|
||||
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm2-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -902,8 +890,8 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
|
||||
{
|
||||
auto const modelName{"chatglm2-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const batchSizes = {2};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -914,8 +902,8 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
|
||||
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
|
||||
{
|
||||
auto const modelName{"chatglm2-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -926,8 +914,8 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
|
||||
|
||||
TEST_F(ChatGlm3SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm3-6b"};
|
||||
auto const modelPath{ENGINGE_PATH / "chatglm"};
|
||||
auto const modelName{"chatglm3_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
|
||||
@ -294,7 +294,7 @@ void testBufferType()
|
||||
using limits = std::numeric_limits<T>;
|
||||
static_assert(dataType.isPointer() || dataType.isUnsigned() != limits::is_signed);
|
||||
static_assert(std::is_same_v<T,
|
||||
typename CppDataType<dataType.getDataType(), dataType.isUnsigned(), dataType.isPointer()>::type>);
|
||||
typename DataTypeTraits<dataType.getDataType(), dataType.isUnsigned(), dataType.isPointer()>::type>);
|
||||
IBuffer::SharedPtr buffer{std::make_shared<HostBuffer>(size, dataType, allocator)};
|
||||
auto bufferPtr = bufferCast<T>(*buffer);
|
||||
auto constexpr max = limits::max();
|
||||
|
||||
@ -57,7 +57,7 @@ void checkFilled(IBuffer& buffer, int fillValue)
|
||||
{
|
||||
if (DType == buffer.getDataType())
|
||||
{
|
||||
EXPECT_THAT(BufferRange<typename CppDataType<DType>::type>(buffer), ::testing::Each(fillValue));
|
||||
EXPECT_THAT(BufferRange<typename DataTypeTraits<DType>::type>(buffer), ::testing::Each(fillValue));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Multi-stage Dockerfile
|
||||
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch
|
||||
ARG BASE_TAG=23.08-py3
|
||||
ARG BASE_TAG=23.10-py3
|
||||
|
||||
FROM ${BASE_IMAGE}:${BASE_TAG} as base
|
||||
|
||||
@ -19,10 +19,16 @@ COPY docker/common/install_cmake.sh install_cmake.sh
|
||||
RUN bash ./install_cmake.sh && rm install_cmake.sh
|
||||
|
||||
# Download & install internal TRT release
|
||||
ARG RELEASE_URL_TRT
|
||||
ARG TARGETARCH
|
||||
ENV RELEASE_URL_TRT=$RELEASE_URL_TRT
|
||||
ENV TRT_TARGETARCH=$TARGETARCH
|
||||
ARG TRT_VER="9.1.0.4"
|
||||
ENV TRT_VER=$TRT_VER
|
||||
ARG CUDA_VER="12.2"
|
||||
ENV CUDA_VER=$CUDA_VER
|
||||
ARG CUDNN_VER="8.9.4.25-1+cuda12.2"
|
||||
ENV CUDNN_VER=$CUDNN_VER
|
||||
ARG NCCL_VER="2.18.3-1+cuda12.2"
|
||||
ENV NCCL_VER=$NCCL_VER
|
||||
ARG CUBLAS_VER="12.2.5.6-1"
|
||||
ENV CUBLAS_VER=$CUBLAS_VER
|
||||
COPY docker/common/install_tensorrt.sh install_tensorrt.sh
|
||||
RUN bash ./install_tensorrt.sh && rm install_tensorrt.sh
|
||||
|
||||
|
||||
@ -27,6 +27,11 @@ DOCKER_PROGRESS ?= auto
|
||||
CUDA_ARCHS ?=
|
||||
BUILD_WHEEL_ARGS ?= $(shell grep 'ARG BUILD_WHEEL_ARGS=' Dockerfile.multi | grep -o '=.*' | tr -d '="')$(if $(CUDA_ARCHS), --cuda_architectures $(CUDA_ARCHS))
|
||||
TORCH_INSTALL_TYPE ?= skip
|
||||
CUDA_VERSION ?=
|
||||
CUDNN_VERSION ?=
|
||||
NCCL_VERSION ?=
|
||||
CUBLAS_VERSION ?=
|
||||
TRT_VERSION ?=
|
||||
|
||||
define add_local_user
|
||||
docker build \
|
||||
@ -50,6 +55,11 @@ endef
|
||||
$(if $(BASE_TAG), --build-arg BASE_TAG=$(BASE_TAG)) \
|
||||
$(if $(BUILD_WHEEL_ARGS), --build-arg BUILD_WHEEL_ARGS="$(BUILD_WHEEL_ARGS)") \
|
||||
$(if $(TORCH_INSTALL_TYPE), --build-arg TORCH_INSTALL_TYPE="$(TORCH_INSTALL_TYPE)") \
|
||||
$(if $(CUDA_VERSION), --build-arg CUDA_VER="$(CUDA_VERSION)") \
|
||||
$(if $(CUDNN_VERSION), --build-arg CUDNN_VER="$(CUDNN_VERSION)") \
|
||||
$(if $(NCCL_VERSION), --build-arg NCCL_VER="$(NCCL_VERSION)") \
|
||||
$(if $(CUBLAS_VERSION), --build-arg CUBLAS_VER="$(CUBLAS_VERSION)") \
|
||||
$(if $(TRT_VERSION), --build-arg TRT_VER="$(TRT_VERSION)") \
|
||||
$(if $(STAGE), --target $(STAGE)) \
|
||||
--file Dockerfile.multi \
|
||||
--tag $(IMAGE_WITH_TAG) \
|
||||
@ -92,23 +102,29 @@ wheel_%: STAGE = wheel
|
||||
|
||||
release_%: STAGE = release
|
||||
|
||||
# For x86_64 and aarch64
|
||||
jenkins_%: IMAGE_WITH_TAG = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
jenkins_%: STAGE = devel
|
||||
|
||||
# For x86_64
|
||||
centos7_%: IMAGE_WITH_TAG = $(shell grep 'LLM_CENTOS7_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
centos7_%: STAGE = devel
|
||||
centos7_%: TORCH_INSTALL_TYPE = src_cxx11_abi
|
||||
centos7_%: BASE_IMAGE = nvidia/cuda
|
||||
centos7_%: BASE_TAG = 12.2.0-devel-centos7
|
||||
centos7_%: BASE_TAG = 12.2.2-devel-centos7
|
||||
|
||||
# For x86_64 and aarch64
|
||||
ubuntu22_%: STAGE = devel
|
||||
ubuntu22_%: TORCH_INSTALL_TYPE = src_cxx11_abi
|
||||
ubuntu22_%: BASE_IMAGE = nvidia/cuda
|
||||
ubuntu22_%: BASE_TAG = 12.2.0-devel-ubuntu22.04
|
||||
ubuntu22_%: BASE_TAG = 12.2.2-devel-ubuntu22.04
|
||||
|
||||
# For x86_64 and aarch64
|
||||
old-cuda_%: IMAGE_WITH_TAG = $(shell grep 'LLM_OLD_CUDA_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
old-cuda_%: BASE_TAG = 23.07-py3
|
||||
old-cuda_%: STAGE = devel
|
||||
old-cuda_%: CUDA_VERSION = 12.1
|
||||
old-cuda_%: CUDNN_VERSION = 8.9.3.28-1+cuda12.1
|
||||
old-cuda_%: NCCL_VERSION = 2.18.3-1+cuda12.1
|
||||
old-cuda_%: CUBLAS_VERSION = 12.1.3.1-1
|
||||
|
||||
build: devel_build ;
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ init_ubuntu() {
|
||||
fi
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
# Remove previous TRT installation
|
||||
if [[ $(apt list --installed | grep libnvinfer) ]]; then
|
||||
apt-get remove --purge -y libnvinfer*
|
||||
@ -63,10 +64,11 @@ init_centos() {
|
||||
yum -y update
|
||||
yum -y install centos-release-scl-rh epel-release
|
||||
# https://gitlab.com/nvidia/container-images/cuda
|
||||
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
|
||||
YUM_CUDA=${CUDA_VERSION/./-}
|
||||
# Consistent with manylinux2014 centos-7 based version
|
||||
yum -y install wget rh-python${PY_VERSION} rh-python${PY_VERSION}-python-devel rh-git227 devtoolset-10 libffi-devel
|
||||
yum -y install wget git-lfs rh-python${PY_VERSION} rh-python${PY_VERSION}-python-devel rh-git227 devtoolset-10 libffi-devel
|
||||
yum -y install openmpi3 openmpi3-devel
|
||||
echo "source scl_source enable rh-git227 rh-python38" >> "${ENV}"
|
||||
echo "source scl_source enable devtoolset-10" >> "${DEVTOOLSET_ENV_FILE}"
|
||||
|
||||
@ -2,28 +2,53 @@
|
||||
|
||||
set -ex
|
||||
|
||||
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
install_ubuntu_requirements() {
|
||||
CUDNN_VERSION="8"
|
||||
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
||||
|
||||
apt-get update
|
||||
apt-get install -y --no-install-recommends libcudnn${CUDNN_VERSION} libcudnn${CUDNN_VERSION}-dev libnccl-dev
|
||||
if [[ $(apt list --installed | grep libcudnn8) ]]; then
|
||||
apt-get remove --purge -y libcudnn8*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libnccl) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libnccl*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libcublas) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcublas*
|
||||
fi
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends libcudnn8=${CUDNN_VER} libcudnn8-dev=${CUDNN_VER}
|
||||
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
|
||||
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
}
|
||||
|
||||
install_centos_requirements() {
|
||||
CUDNN_VERSION="8"
|
||||
CUDNN_VER=$(echo $CUDNN_VER | sed 's/+/./g')
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
yum -y update
|
||||
yum -y install epel-release
|
||||
yum -y install libcudnn${CUDNN_VERSION} libcudnn${CUDNN_VERSION}-devel libnccl-devel
|
||||
yum remove -y libcudnn* && yum -y install libcudnn8-${CUDNN_VER} libcudnn8-devel-${CUDNN_VER}
|
||||
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
|
||||
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
|
||||
yum clean all
|
||||
}
|
||||
|
||||
install_tensorrt() {
|
||||
TENSOR_RT_VERSION="9.1.0.4"
|
||||
CUDA_VERSION="12.2"
|
||||
|
||||
PY_VERSION=$(python -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||
PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||
TRT_CUDA_VERSION="12.2"
|
||||
|
||||
if [ -z "$RELEASE_URL_TRT" ];then
|
||||
ARCH=${TRT_TARGETARCH}
|
||||
@ -32,11 +57,11 @@ install_tensorrt() {
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||
if [ "$ARCH" = "aarch64" ];then OS="ubuntu-22.04"; else OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.1.0/tars/tensorrt-${TENSOR_RT_VERSION}.${OS}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz;
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/9.1.0/tars/tensorrt-${TRT_VER}.${OS}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz;
|
||||
fi
|
||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||
mv /usr/local/TensorRT-${TENSOR_RT_VERSION} /usr/local/tensorrt
|
||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||
pip install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||
rm -rf /tmp/TensorRT.tar
|
||||
echo 'export LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH' >> "${ENV}"
|
||||
|
||||
@@ -33,7 +33,7 @@ For practical examples of H200's performance:
**Max Throughput TP8:**
an online chat agent scenario (ISL/OSL=80/200) with GPT3-175B on a full HGX (TP8) H200 is 1.6x more performant than H100.

<img src="media/H200launch_tps.png" alt="max throughput llama TP1" width="500" height="auto">
<img src="media/H200launch_tps.png" alt="H200 TPS" width="500" height="auto">

<sub>Preliminary measured performance, subject to change.
TensorRT-LLM v0.5.0, TensorRT v9.1.0.4. | Llama-70B: H100 FP8 BS 8, H200 FP8 BS 32 | GPT3-175B: H100 FP8 BS 64, H200 FP8 BS 128 </sub>

@ -322,6 +322,11 @@ batchSize, beamWidth]`_.
|
||||
that enabling that computation may have an impact on performance (the final
LM head has to perform a matrix multiplication on all the context tokens
instead of just the last one),
* `generationLogits`, is a tensor of values on the GPU (same datatype as the
computation type) to store the logits for the generation. Its shape is
`[batchSize, beamWidth, maxOutputLen-1, vocabSizePadded]`. This buffer will only be
filled in if the TensorRT engine was built with the
`gather_all_token_logits` parameter enabled.
* `onTokenGenerated`, is a callback function invoked in the generation loop to
pass newly generated tokens to the caller while the loop continues to
execute. An implementation of that callback must accept the output `ids`

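For illustration only (not part of this change): a minimal sketch of how the `onTokenGenerated` callback described above could be wired into a `GenerationOutput`. The constructor arguments, the `emptyTensor` helper, and the exact callback signature `(ids, step, finished)` are assumptions based on the description above, not code taken from this commit.

```cpp
// Sketch only: names and signatures are assumptions based on the runtime docs above.
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/generationOutput.h"

using namespace tensorrt_llm::runtime;

GenerationOutput makeOutput(BufferManager& manager)
{
    // ids and lengths buffers; the session resizes them during generation.
    GenerationOutput output{manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
        manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};

    // Invoked once per generation step while the loop keeps running;
    // `finished` is expected to be true on the last invocation.
    output.onTokenGenerated = [](GenerationOutput::TensorPtr const& ids, SizeType step, bool finished)
    {
        // e.g. copy the tokens produced at `step` to the host and stream them to the caller
    };
    return output;
}
```

The populated `GenerationOutput` would then be passed to the session's `generate()` call together with the `GenerationInput` and `SamplingConfig`.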