Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#787)

* Update TensorRT-LLM

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

parent d37b507f41
commit deaae40bd7
@@ -45,4 +45,5 @@ repos:
         args:
           - --skip=".git,3rdparty"
           - --exclude-file=examples/whisper/tokenizer.py
-          - --ignore-words-list=rouge,inout,atleast,strat
+          - --ignore-words-list=rouge,inout,atleast,strat,nd
+        exclude: 'tests/llm-test-defs/turtle/test_input_files'
README.md
@@ -108,9 +108,7 @@ concepts used in TensorRT-LLM, we recommend you to read the following

 ## Installation

-*For Windows installation, see [`Windows`](windows/README.md).*
-
-TensorRT-LLM must be built from source, instructions can be found
+The documentation for installing TensorRT-LLM can be found
 [here](./docs/source/installation.md). An image of a Docker container with
 TensorRT-LLM and its Triton Inference Server Backend will be made available
 soon.
@@ -118,6 +116,8 @@ soon.
 The remaining commands in that document must be executed from the TensorRT-LLM
 container.

+*For Windows installation, see [`Windows`](windows/README.md).*
+
 ## Quick Start

 To create a TensorRT engine for an existing model, there are 3 steps:
@@ -206,13 +206,17 @@ Lovelace architectures. Certain limitations may, however, apply.
 Various numerical precisions are supported in TensorRT-LLM. The support for
 some of those numerical features require specific architectures:

-| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
-| :------------------ | :--- | :--- | :--- | :--- | :--- | :--- |
-| Volta (SM70) | Y | Y | N | N | Y | Y |
-| Turing (SM75) | Y | Y | N | N | Y | Y |
-| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y |
-| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
-| Hopper (SM90) | Y | Y | Y | Y | Y | Y |
+| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
+| :------------------ | :--- | :--- | :--- | :--- | :---- | :---- |
+| Volta (SM70) | Y | Y | N | N | Y (1) | Y (2) |
+| Turing (SM75) | Y | Y | N | N | Y (1) | Y (2) |
+| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y (3) |
+| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
+| Hopper (SM90) | Y | Y | Y | Y | Y | Y |

+(1) INT8 SmoothQuant is not supported on SM70 and SM75.<br>
+(2) INT4 AWQ and GPTQ are not supported on SM < 80.<br>
+(3) INT4 AWQ and GPTQ with FP8 activations require SM >= 89.

 In this release of TensorRT-LLM, the support for FP8 and quantized data types
 (INT8 or INT4) is not implemented for all the models. See the
@@ -267,6 +271,7 @@ The list of supported models is:
 * [MPT](examples/mpt)
 * [mT5](examples/enc_dec)
 * [OPT](examples/opt)
+* [Phi-1.5/Phi-2](examples/phi)
 * [Qwen](examples/qwen)
 * [Replit Code](examples/mpt)
 * [SantaCoder](examples/gpt)
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
 # AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
@@ -9,7 +9,7 @@ multiple GPUs or multiple nodes with multiple GPUs.

 Please follow the [`installation document`](../../docs/source/installation.md) to build TensorRT-LLM.

-Note that the benchmarking source code for C++ runtime is not built by default, you can use the argument `--benchmarks` in [`build_wheel.py`](../../scripts/build_wheel.py) to build that.
+Note that the benchmarking source code for C++ runtime is not built by default, you can use the argument `--benchmarks` in [`build_wheel.py`](source:scripts/build_wheel.py) to build the corresponding executable.

 Windows users: Follow the
 [`Windows installation document`](../../windows/README.md)
@@ -22,7 +22,7 @@ instead, and be sure to set DLL paths as specified in

 Before you launch C++ benchmarking, please make sure that you have already built engine(s) using TensorRT-LLM API, C++ benchmarking code cannot generate engine(s) for you.

-You can use the [`build.py`](../python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built by benchmarking code, please see that [`document`](../python/README.md).
+You can use the [`build.py`](source:benchmarks/python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built previously, please see that [`document`](../python/README.md).

 #### Launch benchmarking
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -267,13 +267,19 @@ public:
         batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
         std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId)
     {
+        ReturnBatchManagerStatsCallback iterationDataCallback{nullptr};
+        if (optionalParams.logIterationData)
+        {
+            iterationDataCallback = [this](const std::string& s) { return TLLM_LOG_INFO(s); };
+        }
+
         mBatchManager = std::make_shared<GptManager>(
             trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
             [this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
             [this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
                 const std::string& errMsg)
             { return sendResponse(requestId, response_tensors, final_response, errMsg); },
-            nullptr, nullptr, optionalParams, terminateReqId);
+            nullptr, iterationDataCallback, optionalParams, terminateReqId);
         mRecorder = recorder;
         mTerminateReqId = terminateReqId;
     }
@@ -459,8 +465,14 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
     request->setMaxNewTokens(
         bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
     request->setBeamWidth(beamWidthTensor);
-    request->setEndId(eosId);
-    request->setPadId(padId);
+    if (eosId != nullptr)
+    {
+        request->setEndId(eosId);
+    }
+    if (padId != nullptr)
+    {
+        request->setPadId(padId);
+    }
     return request;
 }
@@ -571,6 +583,8 @@ int main(int argc, char* argv[])

     options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
         cxxopts::value<std::string>()->default_value("error"));
+    options.add_options()(
+        "log_iteration_data", "On each decoder iteration, print batch state metadata.", cxxopts::value<bool>());

     auto result = options.parse(argc, argv);
@@ -618,6 +632,11 @@ int main(int argc, char* argv[])
     {
         optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
     }
+    // Argument: Enable batch stats output
+    if (result.count("log_iteration_data"))
+    {
+        optionalParams.logIterationData = result["log_iteration_data"].as<bool>();
+    }

     std::optional<int32_t> padId;
     // Argument: Padding token id
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -253,6 +253,7 @@ int main(int argc, char* argv[])
     options.add_options()("gen_micro_batch_size", "Batch size for generation phase.", cxxopts::value<int>());
     options.add_options()("max_attention_window", "Max kv cache length per sequence.", cxxopts::value<int>());
     options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
+    options.add_options()("sink_token_len", "Sink token length in kv cache per sequence.", cxxopts::value<int>());
     options.add_options()(
         "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@@ -353,6 +354,11 @@ int main(int argc, char* argv[])
     {
         sessionConfig.kvCacheConfig.maxAttentionWindow = result["max_attention_window"].as<int>();
     }
+    // Argument: Sink token length
+    if (result.count("sink_token_len"))
+    {
+        sessionConfig.kvCacheConfig.sinkTokenLength = result["sink_token_len"].as<int>();
+    }
     // Argument: K-V Cache Free Gpu Mem Fraction
     if (result.count("kv_cache_free_gpu_mem_fraction"))
     {
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -231,7 +231,6 @@ _allowed_configs = {
                     builder_opt=None,
                     pre_norm=False,
                     do_layer_norm_before=False,
-                    use_custom_all_reduce=False,
                 )),
     "opt_2.7b":
     ModelConfig(name="opt_2.7b",
@@ -250,7 +249,6 @@ _allowed_configs = {
                     builder_opt=None,
                     pre_norm=False,
                     do_layer_norm_before=True,
-                    use_custom_all_reduce=False,
                 )),
     "opt_6.7b":
     ModelConfig(name="opt_6.7b",
@@ -269,7 +267,6 @@ _allowed_configs = {
                     builder_opt=None,
                     pre_norm=False,
                     do_layer_norm_before=True,
-                    use_custom_all_reduce=False,
                 )),
     "opt_66b":
     ModelConfig(name="opt_66b",
@@ -288,7 +285,6 @@ _allowed_configs = {
                     builder_opt=None,
                     pre_norm=True,
                     do_layer_norm_before=True,
-                    use_custom_all_reduce=False,
                 )),
     "llama_7b":
     ModelConfig(name="llama_7b",
@@ -515,7 +511,6 @@ _allowed_configs = {
                     max_output_len=200,
                     builder_opt=None,
                     remove_input_padding=False,
-                    use_custom_all_reduce=False,
                 )),
     "bloom_560m":
     ModelConfig(name="bloom_560m",
@@ -532,7 +527,6 @@ _allowed_configs = {
                     max_input_len=1024,
                     max_output_len=1024,
                     builder_opt=None,
-                    use_custom_all_reduce=False,
                 )),
     "bloom_176b":
     ModelConfig(name="bloom_176b",
@@ -549,7 +543,6 @@ _allowed_configs = {
                     max_input_len=1024,
                     max_output_len=1024,
                     builder_opt=None,
-                    use_custom_all_reduce=False,
                 )),
     "bert_base":
     ModelConfig(name="bert_base",
@@ -596,7 +589,7 @@ _allowed_configs = {
                     num_heads=32,
                     hidden_size=2048,
                     vocab_size=50304,
-                    hidden_act=None,
+                    hidden_act='gelu',
                     n_positions=2048,
                     max_batch_size=256,
                     max_input_len=1024,
@@ -617,7 +610,7 @@ _allowed_configs = {
                     num_kv_heads=1,
                     hidden_size=4544,
                     vocab_size=65024,
-                    hidden_act=None,
+                    hidden_act='gelu',
                     n_positions=2048,
                     max_batch_size=128,
                     max_input_len=512,
@@ -638,7 +631,7 @@ _allowed_configs = {
                     num_kv_heads=8,
                     hidden_size=8192,
                     vocab_size=65024,
-                    hidden_act=None,
+                    hidden_act='gelu',
                     n_positions=2048,
                     max_batch_size=64,
                     max_input_len=512,
@@ -659,7 +652,7 @@ _allowed_configs = {
                     num_kv_heads=8,
                     hidden_size=14848,
                     vocab_size=65024,
-                    hidden_act=None,
+                    hidden_act='gelu',
                     n_positions=2048,
                     max_batch_size=8,
                     max_input_len=1024,
@@ -921,6 +914,112 @@ _allowed_configs = {
                     max_output_len=200,
                     builder_opt=None,
                 )),
+    "baichuan_7b":
+    ModelConfig(name="baichuan_7b",
+                family="baichuan_7b",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=32,
+                    num_heads=32,
+                    hidden_size=4096,
+                    vocab_size=64000,
+                    hidden_act='silu',
+                    n_positions=4096,
+                    inter_size=11008,
+                    max_batch_size=128,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                )),
+    "baichuan2_7b_chat":
+    ModelConfig(name="baichuan2_7b_chat",
+                family="baichuan_7b",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=32,
+                    num_heads=32,
+                    hidden_size=4096,
+                    vocab_size=125696,
+                    hidden_act='silu',
+                    n_positions=4096,
+                    inter_size=11008,
+                    max_batch_size=128,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                )),
+    "baichuan_13b_chat":
+    ModelConfig(name="baichuan_13b_chat",
+                family="baichuan_13b",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=40,
+                    num_heads=40,
+                    hidden_size=5120,
+                    vocab_size=64000,
+                    hidden_act='silu',
+                    n_positions=4096,
+                    inter_size=13696,
+                    max_batch_size=64,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                )),
+    "baichuan2_13b_chat":
+    ModelConfig(name="baichuan2_13b_chat",
+                family="baichuan_13b",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=40,
+                    num_heads=40,
+                    hidden_size=5120,
+                    vocab_size=125696,
+                    hidden_act='silu',
+                    n_positions=4096,
+                    inter_size=13696,
+                    max_batch_size=64,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                )),
+    "internlm_chat_7b":
+    ModelConfig(name="internlm_chat_7b",
+                family="internlm",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=32,
+                    num_heads=32,
+                    num_kv_heads=32,
+                    hidden_size=4096,
+                    vocab_size=103168,
+                    hidden_act='silu',
+                    n_positions=2048,
+                    inter_size=11008,
+                    max_batch_size=128,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                    bias=True,
+                )),
+    "internlm_chat_20b":
+    ModelConfig(name="internlm_chat_20b",
+                family="internlm",
+                benchmark_type="gpt",
+                build_config=BuildConfig(
+                    num_layers=60,
+                    num_heads=40,
+                    num_kv_heads=40,
+                    hidden_size=5120,
+                    vocab_size=103168,
+                    hidden_act='silu',
+                    n_positions=4096,
+                    inter_size=13824,
+                    max_batch_size=64,
+                    max_input_len=512,
+                    max_output_len=200,
+                    builder_opt=None,
+                    bias=False,
+                )),
 }
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,7 +89,14 @@ def parse_arguments():
         type=float,
         default="0",
         help=('Specify Top-P value of decoding.'))

+    parser.add_argument(
+        '--profiling_verbosity',
+        type=str,
+        default='layer_names_only',
+        choices=['layer_names_only', 'detailed', 'none'],
+        help=
+        'The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.'
+    )
     parser.add_argument(
         '--log_level',
         type=str,
@@ -180,6 +187,10 @@ def parse_arguments():
         help=
         'Quick sanity check with num_layer=1; will be silently ignored if --engine_dir is specified.'
     )
+    parser.add_argument('--strongly_typed',
+                        default=False,
+                        action='store_true',
+                        help='This option will reduce the building time.')

     parser.add_argument('--csv',
                         default=False,
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -122,6 +122,10 @@ def parse_arguments():
                         default=False,
                         action='store_true',
                         help="Build engines serially")
+    parser.add_argument('--strongly_typed',
+                        default=False,
+                        action='store_true',
+                        help='This option will reduce the building time.')

     parser.add_argument(
         '--rank',
@@ -143,23 +147,19 @@ def parse_arguments():

 def get_quant_mode(quantization):
     quant_mode = QuantMode(0)
-    strongly_typed = False
     use_smooth_quant = False
     per_token = False
     per_channel = False
     weight_only_precision = 'int8'

     if quantization == "fp8":
-        strongly_typed = True
         quant_mode = quant_mode.set_fp8_qdq()
         quant_mode = quant_mode.set_fp8_kv_cache()

     elif quantization == "fp8_gemm":
-        strongly_typed = True
         quant_mode = quant_mode.set_fp8_qdq()

     elif quantization == "fp8_kv_cache":
-        strongly_typed = True
         quant_mode = quant_mode.set_fp8_kv_cache()

     elif quantization == "int8_sq_per_tensor":
@@ -205,7 +205,7 @@ def get_quant_mode(quantization):
     else:
         raise Exception(f'Unexpected quantization: {quantization}')

-    return quant_mode, strongly_typed, use_smooth_quant, weight_only_precision
+    return quant_mode, use_smooth_quant, weight_only_precision


 def build_gpt(args):
@@ -223,6 +223,9 @@ def build_gpt(args):
     if not args.serial_build:
         torch.cuda.set_device(runtime_rank)

+    strongly_typed = args.strongly_typed
+    if args.quantization is not None and "fp8" in args.quantization:
+        strongly_typed = True
     num_kv_heads = build_config['num_heads'] \
         if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
     apply_query_key_layer_scaling = False
@@ -234,7 +237,7 @@ def build_gpt(args):
         if args.max_output_len is None else args.max_output_len
     max_beam_width = build_config['max_beam_width'] \
         if args.max_beam_width is None else args.max_beam_width
-    quant_mode, strongly_typed, use_smooth_quant, weight_only_precision = get_quant_mode(
+    quant_mode, use_smooth_quant, weight_only_precision = get_quant_mode(
         args.quantization)
     use_weight_only = quant_mode.is_weight_only()

@@ -243,6 +246,7 @@ def build_gpt(args):
         name=args.model,
         precision=args.dtype,
         timing_cache=None,
+        profiling_verbosity=args.profiling_verbosity,
         tensor_parallel=world_size,  # TP only
         parallel_build=True,
         num_layers=build_config['num_layers'],
@@ -468,21 +472,117 @@ def build_gpt(args):
        config = PretrainedConfig.from_dict(config)
        tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
    elif family == "falcon":
        tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
        config = {
            'architecture':
            'FalconForCausalLM',
            'dtype':
            args.dtype,
            'num_hidden_layers':
            build_config['num_layers'],
            'num_attention_heads':
            build_config['num_heads'],
            'num_key_value_heads':
            build_config['num_heads'] if build_config['num_kv_heads'] is None
            else build_config['num_kv_heads'],
            'hidden_size':
            build_config['hidden_size'],
            'vocab_size':
            build_config['vocab_size'],
            'position_embedding_type':
            'alibi_with_scale'
            if build_config['use_alibi'] else 'rope_gpt_neox',
            'max_position_embeddings':
            build_config['n_positions'],
            'hidden_act':
            build_config['hidden_act'],
            'quantization': {
                'use_smooth_quant':
                quant_mode.has_act_and_weight_quant(),
                'per_channel':
                quant_mode.has_per_channel_scaling(),
                'per_token':
                quant_mode.has_per_token_dynamic_scaling(),
                'per_group':
                quant_mode.has_per_group_scaling(),
                'group_size':
                128,
                'int8_kv_cache':
                quant_mode.has_int8_kv_cache(),
                'enable_fp8':
                quant_mode.has_fp8_qdq(),
                'fp8_kv_cache':
                quant_mode.has_fp8_kv_cache(),
                'use_weight_only':
                quant_mode.is_weight_only(),
                'weight_only_precision':
                'int8' if quant_mode.is_int8_weight_only() else 'int4',
            },
            'mapping': {
                'world_size': world_size,
                'tp_size': world_size
            },
            'bias':
            build_config['bias'],
            'parallel_attention':
            build_config['parallel_attention'],
            'new_decoder_architecture':
            build_config['new_decoder_architecture'],
        }
        if quant_mode.is_weight_only() and quant_mode.has_per_group_scaling():
            config['quantization'].update({
                'zero': False,
                'pre_quant_scale': True,
                'exclude_modules': [],
            })
        config = PretrainedConfig.from_dict(config)
        tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
    elif family == "baichuan_7b":
        tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
            num_layers=build_config['num_layers'],
            num_heads=build_config['num_heads'],
            num_kv_heads=None,
            hidden_size=build_config['hidden_size'],
            vocab_size=build_config['vocab_size'],
            hidden_act=build_config['hidden_act'],
            max_position_embeddings=build_config['n_positions'],
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            dtype=kv_dtype,
            mlp_hidden_size=build_config['inter_size'],
            mapping=tensorrt_llm.Mapping(world_size=world_size,
                                         tp_size=world_size),
            quant_mode=quant_mode)
    elif family == "baichuan_13b":
        tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
            num_layers=build_config['num_layers'],
            num_heads=build_config['num_heads'],
            num_kv_heads=None,
            hidden_size=build_config['hidden_size'],
            vocab_size=build_config['vocab_size'],
            hidden_act=build_config['hidden_act'],
            max_position_embeddings=build_config['n_positions'],
            position_embedding_type=PositionEmbeddingType.alibi,
            dtype=kv_dtype,
            mlp_hidden_size=build_config['inter_size'],
            mapping=tensorrt_llm.Mapping(world_size=world_size,
                                         tp_size=world_size),
            quant_mode=quant_mode)
    elif family == "internlm":
        tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
            num_layers=build_config['num_layers'],
            num_heads=build_config['num_heads'],
            num_kv_heads=num_kv_heads,
            hidden_size=build_config['hidden_size'],
            vocab_size=build_config['vocab_size'],
            hidden_act=build_config['hidden_act'],
            max_position_embeddings=build_config['n_positions'],
            dtype=kv_dtype,
            bias=build_config['bias'],
            quant_mode=quant_mode,
            use_alibi=build_config['use_alibi'],
            new_decoder_architecture=build_config['new_decoder_architecture'],
            parallel_attention=build_config['parallel_attention'],
            mlp_hidden_size=build_config['inter_size'],
            mapping=tensorrt_llm.Mapping(world_size=world_size,
                                         tp_size=world_size))
                                         tp_size=world_size),  # TP only
            quant_mode=quant_mode,
            embedding_sharding_dim=1,
            use_fused_mlp=False,
            attn_bias=build_config['bias'])
    else:
        raise Exception(f'Unexpected model: {args.model}')
@@ -501,7 +601,7 @@ def build_gpt(args):
             "zero": True,
             "pre_quant_scale": False,
         }
-        if family not in ['opt', 'bloom']:
+        if family not in ['opt', 'bloom', 'falcon']:
             tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
                                                 **quant_kwargs)

@@ -550,7 +650,7 @@ def build_gpt(args):
                                  max_input_len,
                                  max_output_len, True,
                                  max_beam_width)
-        if family in ['opt', 'bloom']:
+        if family in ['opt', 'bloom', 'falcon']:
             tensorrt_llm_model(**inputs)
         else:
             tensorrt_llm_model(*inputs)
@@ -604,6 +704,7 @@ def build_bert(args):
         name=args.model,
         precision=args.dtype,
         timing_cache=None,
+        profiling_verbosity=args.profiling_verbosity,
         tensor_parallel=world_size,  # TP only
         parallel_build=True,
         num_layers=build_config['num_layers'],
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,7 +56,7 @@ class GPTBenchmark(BaseBenchmark):
         if args.max_output_len is not None:
             self.max_output_len = args.max_output_len

-        self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
+        self.quant_mode, _, _ = get_quant_mode(args.quantization)
         self.enable_fp8 = self.quant_mode.has_fp8_qdq()
         self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -170,6 +170,7 @@ public:
                                                                                           \
     void set##funcName(TensorPtr const& tensor)                                           \
     {                                                                                     \
+        TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
         mInputTensors[tensorName] = tensor;                                               \
     }
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -31,9 +31,11 @@ public:

     explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
         std::optional<SizeType> maxAttentionWindow = std::nullopt,
+        std::optional<SizeType> sinkTokenLength = std::nullopt,
         std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false)
         : maxTokens{maxTokens}
         , maxAttentionWindow{maxAttentionWindow}
+        , sinkTokenLength{sinkTokenLength}
         , freeGpuMemoryFraction{freeGpuMemoryFraction}
         , enableBlockReuse(enableBlockReuse)
     {
@@ -41,10 +43,9 @@ public:

     std::optional<SizeType> maxTokens;
     std::optional<SizeType> maxAttentionWindow;
+    std::optional<SizeType> sinkTokenLength;
     std::optional<float> freeGpuMemoryFraction;
     bool enableBlockReuse;

-    static constexpr auto kDefaultGpuMemFraction = 0.85f;
 };

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
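For illustration, a minimal compilable sketch of how the widened constructor is meant to be filled in. The struct below is only a stand-in with the same field layout as the KvCacheConfig declared in the hunk above (the real class lives in the TensorRT-LLM batch manager headers), and the values are illustrative assumptions, not library defaults:

```cpp
// Stand-in mirroring the KvCacheConfig fields shown in the hunk above.
#include <cstdint>
#include <optional>

using SizeType = std::int32_t;

struct KvCacheConfigSketch
{
    std::optional<SizeType> maxTokens;
    std::optional<SizeType> maxAttentionWindow;
    std::optional<SizeType> sinkTokenLength; // new in this change
    std::optional<float> freeGpuMemoryFraction;
    bool enableBlockReuse;
};

int main()
{
    // Illustrative values: keep 4 "sink" tokens per sequence and let the cache
    // claim 90% of free GPU memory; the other limits are left unset.
    KvCacheConfigSketch config{/*maxTokens=*/std::nullopt,
        /*maxAttentionWindow=*/std::nullopt,
        /*sinkTokenLength=*/4,
        /*freeGpuMemoryFraction=*/0.9f,
        /*enableBlockReuse=*/false};
    return config.sinkTokenLength.value_or(0) == 4 ? 0 : 1;
}
```

This is the same field that the `--sink_token_len` option writes into `sessionConfig.kvCacheConfig.sinkTokenLength` in the gptSessionBenchmark change earlier in this commit.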
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -129,6 +129,8 @@ public:

     [[nodiscard]] bool isFull() const;

+    [[nodiscard]] bool isShared() const;
+
 private:
     // Linear index of block in pool
     SizeType mBlockIdx;
@@ -199,6 +201,11 @@ public:
         mCacheBlockIds.at(beamIdx).push_back(blockIdx);
     }

+    void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, SizeType blockIdx)
+    {
+        mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockIdx;
+    }
+
     void clearCacheBlocks()
     {
         for (auto& beamBlockIds : mCacheBlockIds)
@@ -259,12 +266,14 @@ public:
     void addSequence(GenerationRequest& sequence, SizeType inputLength, std::shared_ptr<LlmRequest> const& llmRequest);

     //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
-    void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);
+    void addSequence(GenerationRequest& sequence, SizeType numBlocks, SizeType unsharedBlockIdx);

     //! \brief Allocate new block for each beam of the sequence.
     //! \details Might free cached blocks if no free blocks are available.
     void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);

+    void replaceSharedBlock(GenerationRequest& sequence, SizeType blockIdx);
+
     //! \brief Release blocks of the sequence. Store blocks for reuse if llmReqeust is provided.
     void releaseBlocks(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);

@@ -354,8 +363,8 @@ public:

     KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
         SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth,
-        SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, nvinfer1::DataType dtype, CudaStreamPtr stream,
-        bool enableBlockReuse = false);
+        SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock,
+        nvinfer1::DataType dtype, CudaStreamPtr stream, bool enableBlockReuse = false);

     void startScheduling();

@@ -466,6 +475,7 @@ private:
     void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
     void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
     void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
+    void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx);

 private:
     // Number of elements per one blocks
@@ -479,6 +489,14 @@ private:
     // Maximum kv cache length per sequence
     // Enable cyclic kv cache when it exceeds
     SizeType mMaxAttentionWindow;
+    // Sink token length in the kv cache per sequence
+    SizeType mSinkTokenLength;
+    // Bubble token length
+    SizeType mBubbleLength;
+    // Maximum token length (including bubble)
+    SizeType mMaxTokenNum;
+    // Number of tokens in the sink blocks
+    SizeType mSinkBlockTokenLength;
     // Pools
     std::vector<runtime::ITensor::SharedPtr> mPools;
     // Block manager
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,12 +34,14 @@ public:

     explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
         std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
-        std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true)
+        std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true,
+        bool logIterationData = false)
         : kvCacheConfig{kvCacheConfig}
         , maxNumSequences{maxNumSequences}
         , enableTrtOverlap{enableTrtOverlap}
         , deviceIds(deviceIds)
         , normalizeLogProbs{normalizeLogProbs}
         , logIterationData{logIterationData}
     {
     }

@@ -48,6 +50,7 @@ public:
     bool enableTrtOverlap;
     std::optional<std::vector<SizeType>> deviceIds;
     bool normalizeLogProbs;
+    bool logIterationData;
 };

 } // namespace tensorrt_llm::batch_manager
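A similar sketch for the new logIterationData flag. Again this is a stand-in type with the same members as the TrtGptModelOptionalParams shown above, used only to illustrate the default value and what the benchmark flag flips:

```cpp
// Stand-in mirroring TrtGptModelOptionalParams as declared in the hunk above.
#include <cstdint>
#include <optional>
#include <vector>

using SizeType = std::int32_t;

struct OptionalParamsSketch
{
    std::optional<SizeType> maxNumSequences{};
    bool enableTrtOverlap{true};
    std::optional<std::vector<SizeType>> deviceIds{};
    bool normalizeLogProbs{true};
    bool logIterationData{false}; // new in this change; off by default
};

int main()
{
    OptionalParamsSketch params;
    // This is what gptManagerBenchmark's --log_iteration_data option sets; when true,
    // the benchmark installs a stats callback that logs batch state on every decoder iteration.
    params.logIterationData = true;
    return params.logIterationData ? 0 : 1;
}
```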
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -29,11 +29,12 @@ class DecodingInput
 public:
     using TensorPtr = std::shared_ptr<ITensor const>;

-    DecodingInput(
-        SizeType maxLength, SizeType maxAttentionWindow, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
+    DecodingInput(SizeType maxLength, SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType batchSize,
+        TensorPtr logits, TensorPtr endIds)
         : step{maxLength}
         , maxLength{maxLength}
         , maxAttentionWindow{maxAttentionWindow}
+        , sinkTokenLength{sinkTokenLength}
         , batchSize{batchSize}
         , logits{std::move(logits)}
         , endIds{std::move(endIds)}
@@ -46,6 +47,7 @@ public:
     SizeType step;
     SizeType maxLength;
     SizeType maxAttentionWindow;
+    SizeType sinkTokenLength;
     SizeType batchSize;
     TensorPtr logits;  // [batchSize, beamWidth, vocabSizePadded], on gpu
     TensorPtr endIds;  // [batchSize * beamWidth], on gpu
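The DecodingInput constructor now threads the sink token length through as well. A reduced stand-in showing the new argument order (tensor members elided), with illustrative values rather than anything taken from the library:

```cpp
// Reduced stand-in for the DecodingInput constructor shown above.
#include <cstdint>

using SizeType = std::int32_t;

struct DecodingInputSketch
{
    DecodingInputSketch(SizeType maxLength, SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType batchSize)
        : step{maxLength}
        , maxLength{maxLength}
        , maxAttentionWindow{maxAttentionWindow}
        , sinkTokenLength{sinkTokenLength} // new argument, inserted before batchSize
        , batchSize{batchSize}
    {
    }

    SizeType step;
    SizeType maxLength;
    SizeType maxAttentionWindow;
    SizeType sinkTokenLength;
    SizeType batchSize;
};

int main()
{
    DecodingInputSketch input(/*maxLength=*/1024, /*maxAttentionWindow=*/1024,
        /*sinkTokenLength=*/4, /*batchSize=*/8);
    return input.sinkTokenLength == 4 ? 0 : 1;
}
```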
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -45,8 +45,8 @@ public:
     GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);

     //! Setup the decoder before calling `forward()`
-    void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
-        SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
+    void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
+        SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;

     //! @brief Initialize the decoder at `batchIdx` with a new `request`.
     void newRequest(
@@ -201,6 +201,7 @@ private:
     // decoding accept by logits kernel, on gpu
     SizeType mMaxSequenceLength{};
     SizeType mMaxAttentionWindow{};
+    SizeType mSinkTokenLength{};
     SizeType mActualBatchSize{};
     SizeType mMaxTokensPerStep{};
 };
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -145,10 +145,10 @@ private:

     void createContexts();
     void createBuffers(SizeType numMicroBatches);
-    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
-        nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
+    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
+        SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
     void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
-        SizeType maxSequenceLength, KvCacheConfig const& config);
+        SizeType sinkTokenLength, SizeType maxSequenceLength, KvCacheConfig const& config);
     void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);

     void executeContextStep(std::vector<GenerationInput> const& microBatchesInputs,
@@ -257,6 +257,7 @@ private:

     SizeType mDecoderMaxSequenceLength{};
     SizeType mDecoderMaxAttentionWindow{};
+    SizeType mDecoderSinkTokenLength{};

     LoggerPtr mLogger;
     std::shared_ptr<TllmRuntime> mRuntime;
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -75,7 +75,7 @@ public:

     //! Setup the decoder before calling `forward()`, also calls reshapeBuffers
     virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
-        SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
+        SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
         = 0;

     //! @brief Initialize the decoder with new batch of inputs.
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,6 +1,6 @@

 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
 # AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b5007e359b3c93562b81ea6dbf414a6cb98a88de2f82aebb740a044a2deb3946
-size 1846872
+oid sha256:327edb4d1e50392467f194cb8ccacad39d2d872d1f89aef79cafa203171a4734
+size 1858074
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:256cd1e9da6a7e77ed2981abbeca7fb18660b9469643791333d9a909c01bc601
-size 1860514
+oid sha256:e55ee683c569bde1fd18442152b201cf4ebb41bbe21c6c1c6abfc5bac6256e5f
+size 1873024
@@ -1,3 +1,3 @@
-097b6a4d83e0c954f5b6510c0e871e87 libtensorrt_llm_batch_manager_static.a
-9c73d43e5b1a94dbd16c94237905719e libtensorrt_llm_batch_manager_static.pre_cxx11.a
-0067d2b225abceb20316ddb3a5cddf426ef25160 commit
+a4ea2e88effa61e397d984943fd47802 libtensorrt_llm_batch_manager_static.a
+94a233ead0aca9a51776a0bab70c59b4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+9251c03ebe80b196becd8ac3abbfa02c2a3273ad commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c95be2b543ce79a591b12569fc5d245b76c72b9e0485b17bfb5f16bc46fa7029
-size 1775504
+oid sha256:0a8dc8411449452686afc7b4005cdb77905914edc9d5257d5d283b0dfc4eb9aa
+size 1790812
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fe2ddf4d130704e82f0ab299b203408684432e9c111bf481db2b34a9a1023b83
-size 1763222
+oid sha256:fda2fdc9c2b3672e94f15927b6bbeb5321c436761c8d1b1de96f23f6807a351a
+size 1776536
@@ -1,2 +1,2 @@
-cd974e5d12c72241dec0c6b439f2f7a0 libtensorrt_llm_batch_manager_static.a
-4e88c3c8609582273a1b026fadac4abd libtensorrt_llm_batch_manager_static.pre_cxx11.a
+1c62064ee5f68d76194bad877504b0d3 libtensorrt_llm_batch_manager_static.a
+4d7b4ff0c4c14865a18535cef6117ee9 libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -535,9 +535,13 @@ public:
         const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
         cudaStream_t stream)
     {
+        KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
+        pagedKVCacheForContextMHA = KVBlockArrayForContextFMHA(pagedKVCache.mMaxSeqs, pagedKVCache.mMaxBlocksPerSeq,
+            pagedKVCache.mTokensPerBlock, mPagedKVParams.h_kv * mPagedKVParams.d * sizeof(half));
+        pagedKVCacheForContextMHA.data = pagedKVCache.data;
         mPagedKVParams.q_ptr = qPtr;
         mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
-        mPagedKVParams.paged_kv_cache = pagedKVCache;
+        mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA;
         mPagedKVParams.o_ptr = outputPtr;
         mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
         mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
@@ -179,7 +179,7 @@ struct Fused_multihead_attention_paged_kv_params_v2
     // The Q matrices.
     const void* q_ptr;
     // Paged KV Cache buffer.
-    KVBlockArray paged_kv_cache;
+    KVBlockArrayForContextFMHA paged_kv_cache;
     // The O matrix (output).
     void* o_ptr;
     // The packed mask for random mask.
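The two hunks above switch the context FMHA path from the raw KVBlockArray to a KVBlockArrayForContextFMHA view built from mMaxSeqs, mMaxBlocksPerSeq, mTokensPerBlock and the per-token byte size (h_kv * d * sizeof(half)). As a rough, hypothetical sketch of what such a paged KV view has to provide (illustrative names only, not the TensorRT-LLM classes), a block-paged cache resolves a (sequence, token) pair into a block pointer plus an in-block byte offset:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical paged-KV view: every sequence owns a row of block pointers and
    // each block holds `tokensPerBlock` tokens of `bytesPerToken` bytes.
    struct PagedKvView
    {
        int32_t maxSeqs;
        int32_t maxBlocksPerSeq;
        int32_t tokensPerBlock;
        int32_t bytesPerToken; // e.g. numKvHeads * headSize * sizeof(half)
        void** blockPtrs;      // [maxSeqs * maxBlocksPerSeq], row-major

        // Block that holds token `tokenIdx` of sequence `seqIdx`.
        void* blockPtr(int32_t seqIdx, int32_t tokenIdx) const
        {
            return blockPtrs[seqIdx * maxBlocksPerSeq + tokenIdx / tokensPerBlock];
        }

        // Byte offset of that token inside its block.
        int32_t localOffset(int32_t tokenIdx) const
        {
            return (tokenIdx % tokensPerBlock) * bytesPerToken;
        }
    };

    int main()
    {
        std::vector<char> block0(64 * 256), block1(64 * 256);
        void* blocks[2] = {block0.data(), block1.data()};
        PagedKvView view{/*maxSeqs=*/1, /*maxBlocksPerSeq=*/2, /*tokensPerBlock=*/64, /*bytesPerToken=*/256, blocks};
        std::printf("token 70 -> block %p, offset %d\n", view.blockPtr(0, 70), view.localOffset(70));
        return 0;
    }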
@ -29,7 +29,8 @@ namespace mmha
|
||||
|
||||
// Forward declaration of the kernel launcher to avoid including decoderMaskedMultiheadAttentionLaunch.h
|
||||
template <typename T, typename KVCacheBuffer, typename T_PARAMS, int Dh>
|
||||
void mmha_launch_kernel(const T_PARAMS& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream);
|
||||
void mmha_launch_kernel(const T_PARAMS& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
|
||||
|
||||
} // namespace mmha
|
||||
|
||||
@ -37,44 +38,67 @@ namespace
|
||||
{
|
||||
|
||||
#define MMHA_LAUNCH_KERNEL(Dh) \
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, Dh>(params, kv_cache_buffer, stream); \
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, Dh>( \
|
||||
params, kv_cache_buffer, shift_k_cache, stream); \
|
||||
break;
|
||||
|
||||
template <typename T, typename KVCacheBuffer, typename KERNEL_PARAMS_TYPE>
|
||||
void multihead_attention_(
|
||||
const KERNEL_PARAMS_TYPE& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
|
||||
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
|
||||
{
|
||||
switch (params.hidden_size_per_head)
|
||||
{
|
||||
case 32: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 32>(params, kv_cache_buffer, stream); break;
|
||||
case 64: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 64>(params, kv_cache_buffer, stream); break;
|
||||
case 32:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 32>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 64:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 64>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 128:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 256:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
#ifndef FAST_BUILD // skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
|
||||
case 48: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(params, kv_cache_buffer, stream); break;
|
||||
case 80: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 80>(params, kv_cache_buffer, stream); break;
|
||||
case 96: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 96>(params, kv_cache_buffer, stream); break;
|
||||
case 48:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 80:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 80>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 96:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 96>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 112:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 112>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 112>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 144:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 144>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 144>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 160:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 160>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 160>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 192:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 192>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 192>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
case 224:
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 224>(params, kv_cache_buffer, stream);
|
||||
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 224>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
break;
|
||||
#endif // FAST_BUILD
|
||||
default: TLLM_THROW("unsupported head_size");
|
||||
default: TLLM_CHECK_WITH_INFO(false, "unsupported head_size %d", params.hidden_size_per_head);
|
||||
}
|
||||
}
|
||||
|
||||
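The rewritten switch above is the usual pattern of turning the runtime head size into a compile-time template parameter, now threading the extra shift_k_cache argument through every case and reporting the offending value on the unsupported-size path. A minimal sketch of the same dispatch idea, with hypothetical names:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    // Launcher specialized on the compile-time head size Dh.
    template <int Dh>
    void launchForHeadSize(int batch)
    {
        std::printf("launching Dh=%d kernel for batch=%d\n", Dh, batch);
    }

    // Runtime head size -> compile-time instantiation, mirroring multihead_attention_().
    void launch(int headSize, int batch)
    {
        switch (headSize)
        {
        case 64: launchForHeadSize<64>(batch); break;
        case 128: launchForHeadSize<128>(batch); break;
        case 256: launchForHeadSize<256>(batch); break;
        default: throw std::runtime_error("unsupported head_size " + std::to_string(headSize));
        }
    }

    int main()
    {
        launch(128, 4);
        return 0;
    }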
@ -86,16 +110,16 @@ void multihead_attention_(
|
||||
|
||||
#define INSTANTIATE_MMHA_NORMAL_AND_PAGED(T, CROSS_ATTENTION) \
|
||||
void masked_multihead_attention(const Multihead_attention_params<T, CROSS_ATTENTION>& params, \
|
||||
const KVBlockArray& kv_cache_buffer, const cudaStream_t& stream) \
|
||||
const KVBlockArray& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) \
|
||||
{ \
|
||||
multihead_attention_<T, KVBlockArray, Multihead_attention_params<T, CROSS_ATTENTION>>( \
|
||||
params, kv_cache_buffer, stream); \
|
||||
params, kv_cache_buffer, shift_k_cache, stream); \
|
||||
} \
|
||||
void masked_multihead_attention(const Multihead_attention_params<T, CROSS_ATTENTION>& params, \
|
||||
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream) \
|
||||
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) \
|
||||
{ \
|
||||
multihead_attention_<T, KVLinearBuffer, Multihead_attention_params<T, CROSS_ATTENTION>>( \
|
||||
params, kv_cache_buffer, stream); \
|
||||
params, kv_cache_buffer, shift_k_cache, stream); \
|
||||
}
|
||||
INSTANTIATE_MMHA_NORMAL_AND_PAGED(float, true)
|
||||
INSTANTIATE_MMHA_NORMAL_AND_PAGED(float, false)
|
||||
|
||||
@ -108,6 +108,8 @@ struct Multihead_attention_params_base
|
||||
int max_attention_window_size = 0;
|
||||
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
|
||||
int cyclic_attention_window_size = 0;
|
||||
// Length of the sink token in KV cache
|
||||
int sink_token_length = 0;
|
||||
// The number of heads (H).
|
||||
int num_heads = 0;
|
||||
// Controls MHA/MQA/GQA
|
||||
@ -122,6 +124,8 @@ struct Multihead_attention_params_base
|
||||
RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
|
||||
float rotary_embedding_scale = 0.0f;
|
||||
int rotary_embedding_max_positions = 0;
|
||||
// Position shift for streamingllm
|
||||
bool position_shift_enabled = false;
|
||||
// The current timestep. TODO: check whether this param is only needed for cross attention.
|
||||
int timestep = 0;
|
||||
// The current timestep of each sentence (a different timestep per sentence is supported)
|
||||
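The new cyclic_attention_window_size, sink_token_length and position_shift_enabled fields carry the StreamingLLM-style cache configuration into the kernel parameters: sink tokens stay in place while the remaining window is reused cyclically. The sketch below is my own simplification of where a new token could be written under such a scheme; it is not the kernel's exact indexing, which goes through KVBlockArray and getKVTokenIdx:

    #include <cassert>
    #include <cstdio>

    // Illustrative only: keep `sinkLen` sink tokens pinned and wrap the remaining
    // (windowLen - sinkLen) slots cyclically once `pos` runs past the window.
    int cyclicWriteSlot(int pos, int windowLen, int sinkLen)
    {
        assert(windowLen > sinkLen);
        if (pos < windowLen)
        {
            return pos; // cache not full yet, write in order
        }
        return sinkLen + (pos - sinkLen) % (windowLen - sinkLen);
    }

    int main()
    {
        // Window of 8 slots, 2 of which are sink tokens.
        for (int pos = 0; pos < 12; ++pos)
        {
            std::printf("token %2d -> slot %d\n", pos, cyclicWriteSlot(pos, 8, 2));
        }
        return 0;
    }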
@ -222,13 +226,13 @@ using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
|
||||
|
||||
#define DECLARE_MMHA_NORMAL_AND_PAGED(T) \
|
||||
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
|
||||
const KVBlockArray& block_array, const cudaStream_t& stream); \
|
||||
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
|
||||
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream); \
|
||||
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
|
||||
const KVBlockArray& block_array, const cudaStream_t& stream); \
|
||||
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
|
||||
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream);
|
||||
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
|
||||
DECLARE_MMHA_NORMAL_AND_PAGED(float);
|
||||
DECLARE_MMHA_NORMAL_AND_PAGED(uint16_t);
|
||||
#ifdef ENABLE_BF16
|
||||
|
||||
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@ -149,15 +149,15 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
|
||||
if (dynamic_smem_sz >= 46 * 1024) \
|
||||
{ \
|
||||
cudaError_t res = cudaFuncSetAttribute( \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>, \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
|
||||
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
|
||||
TLLM_CHECK_WITH_INFO( \
|
||||
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
|
||||
} \
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>, \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
|
||||
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>, \
|
||||
DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz));
|
||||
|
||||
#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK, ENABLE_MULTI_BLOCK) \
|
||||
@ -166,16 +166,17 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \
|
||||
if (dynamic_smem_sz >= 46 * 1024) \
|
||||
{ \
|
||||
cudaError_t res = cudaFuncSetAttribute( \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
|
||||
cudaError_t res \
|
||||
= cudaFuncSetAttribute(mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, \
|
||||
KCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, \
|
||||
HAS_BEAMS, ENABLE_MULTI_BLOCK, POS_SHIFT>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
|
||||
TLLM_CHECK_WITH_INFO( \
|
||||
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
|
||||
} \
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK> \
|
||||
<<<grid, DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz, stream>>>(params, kv_cache_buffer);
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
|
||||
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK, POS_SHIFT> \
|
||||
<<<grid, DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz, stream>>>(params, kv_cache_buffer, k_cache_buffer);
|
||||
|
||||
// If resources are not enough to launch 512 threads per block, we will fall back to 256.
|
||||
#define MMHA_512_BLOCKSIZE_CHECK() \
|
||||
@ -204,10 +205,10 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename T_cache, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK,
|
||||
bool HAS_BEAMS, bool DO_MULTI_BLOCK>
|
||||
void mmha_launch_kernel_ex(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream, int tlength)
|
||||
template <typename T, typename T_cache, typename TKcache, typename KVCacheBuffer, typename KCacheBuffer,
|
||||
typename KernelParamsType, int Dh, int THDS_PER_BLOCK, bool HAS_BEAMS, bool DO_MULTI_BLOCK, bool POS_SHIFT>
|
||||
void mmha_launch_kernel_ex(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KCacheBuffer& k_cache_buffer, const cudaStream_t& stream, int tlength)
|
||||
{
|
||||
dim3 grid{static_cast<unsigned>(params.num_heads), static_cast<unsigned>(params.batch_size), 1};
|
||||
|
||||
@ -225,8 +226,8 @@ void mmha_launch_kernel_ex(
|
||||
// Set 0 dynamic shared memory size as we need the number of available blocks limited by registers.
|
||||
// Dynamic shared memory is fixed for different block size.
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, THDS_PER_BLOCK,
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>,
|
||||
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, THDS_PER_BLOCK,
|
||||
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>,
|
||||
THDS_PER_BLOCK, 0));
|
||||
|
||||
int block_size_factor
|
||||
@ -289,61 +290,80 @@ void mmha_launch_kernel_ex(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T_cache, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK,
|
||||
bool HAS_BEAMS, bool DO_MULTI_BLOCK>
|
||||
void mmha_launch_kernel_dispatch_pos_shift(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength)
|
||||
{
|
||||
if (params.position_shift_enabled && !KernelParamsType::DO_CROSS_ATTENTION)
|
||||
{
|
||||
mmha_launch_kernel_ex<T, T_cache, T, KVCacheBuffer, KVLinearBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
|
||||
HAS_BEAMS, DO_MULTI_BLOCK, true>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha_launch_kernel_ex<T, T_cache, T_cache, KVCacheBuffer, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
|
||||
HAS_BEAMS, DO_MULTI_BLOCK, false>(params, kv_cache_buffer, kv_cache_buffer, stream, tlength);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK, bool HAS_BEAMS,
|
||||
bool DO_MULTI_BLOCK>
|
||||
void mmha_launch_kernel_dispatch_8bits_kv_cache(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream, int tlength)
|
||||
void mmha_launch_kernel_dispatch_8bits_kv_cache(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength)
|
||||
{
|
||||
if (params.int8_kv_cache)
|
||||
{
|
||||
mmha_launch_kernel_ex<T, int8_t, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
|
||||
DO_MULTI_BLOCK>(params, kv_cache_buffer, stream, tlength);
|
||||
mmha_launch_kernel_dispatch_pos_shift<T, int8_t, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
|
||||
DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
#ifdef ENABLE_FP8
|
||||
else if (params.fp8_kv_cache)
|
||||
{
|
||||
mmha_launch_kernel_ex<T, __nv_fp8_e4m3, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
|
||||
DO_MULTI_BLOCK>(params, kv_cache_buffer, stream, tlength);
|
||||
mmha_launch_kernel_dispatch_pos_shift<T, __nv_fp8_e4m3, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
|
||||
HAS_BEAMS, DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
else
|
||||
{
|
||||
mmha_launch_kernel_ex<T, T, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS, DO_MULTI_BLOCK>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
mmha_launch_kernel_dispatch_pos_shift<T, T, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
|
||||
DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh, bool HAS_BEAMS>
|
||||
void mmha_launch_kernel_dispatch(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
|
||||
void mmha_launch_kernel_dispatch(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
|
||||
{
|
||||
int const tlength = params.timestep;
|
||||
if (params.multi_block_mode)
|
||||
{
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, true>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, false>(
|
||||
params, kv_cache_buffer, stream, tlength);
|
||||
params, kv_cache_buffer, shift_k_cache, stream, tlength);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh>
|
||||
void mmha_launch_kernel(
|
||||
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
|
||||
void mmha_launch_kernel(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
|
||||
{
|
||||
assert((params.rotary_embedding_dim != 0)
|
||||
== (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ));
|
||||
if (params.beam_width == 1)
|
||||
{
|
||||
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false>(params, kv_cache_buffer, stream);
|
||||
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, true>(params, kv_cache_buffer, stream);
|
||||
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, true>(
|
||||
params, kv_cache_buffer, shift_k_cache, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -352,16 +372,16 @@ void mmha_launch_kernel(
|
||||
#define INSTANTIATE_MMHA_LAUNCHERS(T, Dh) \
|
||||
template void mmha_launch_kernel<T, KVLinearBuffer, Masked_multihead_attention_params<T>, Dh>( \
|
||||
const Masked_multihead_attention_params<T>& params, const KVLinearBuffer& kv_cache_buffer, \
|
||||
const cudaStream_t& stream); \
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
template void mmha_launch_kernel<T, KVBlockArray, Masked_multihead_attention_params<T>, Dh>( \
|
||||
const Masked_multihead_attention_params<T>& params, const KVBlockArray& kv_cache_buffer, \
|
||||
const cudaStream_t& stream); \
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
template void mmha_launch_kernel<T, KVLinearBuffer, Cross_multihead_attention_params<T>, Dh>( \
|
||||
const Cross_multihead_attention_params<T>& params, const KVLinearBuffer& kv_cache_buffer, \
|
||||
const cudaStream_t& stream); \
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
|
||||
template void mmha_launch_kernel<T, KVBlockArray, Cross_multihead_attention_params<T>, Dh>( \
|
||||
const Cross_multihead_attention_params<T>& params, const KVBlockArray& kv_cache_buffer, \
|
||||
const cudaStream_t& stream);
|
||||
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
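Taken together, the launcher changes above peel each runtime property (8-bit KV cache type, position_shift_enabled, multi-block mode) off into a template parameter, so every combination gets its own fully specialized kernel and the new shift_k_cache argument is forwarded through each hop. A reduced sketch of that dispatch pattern, with hypothetical names:

    #include <cstdio>

    template <bool POS_SHIFT, bool MULTI_BLOCK>
    void run_kernel(int tokens)
    {
        std::printf("pos_shift=%d multi_block=%d tokens=%d\n", POS_SHIFT, MULTI_BLOCK, tokens);
    }

    template <bool POS_SHIFT>
    void dispatch_multi_block(bool multiBlock, int tokens)
    {
        if (multiBlock) { run_kernel<POS_SHIFT, true>(tokens); }
        else            { run_kernel<POS_SHIFT, false>(tokens); }
    }

    // Each runtime flag is converted to a template parameter in turn, mirroring
    // mmha_launch_kernel_dispatch_pos_shift / mmha_launch_kernel_dispatch_8bits_kv_cache above.
    void dispatch(bool posShift, bool multiBlock, int tokens)
    {
        if (posShift) { dispatch_multi_block<true>(multiBlock, tokens); }
        else          { dispatch_multi_block<false>(multiBlock, tokens); }
    }

    int main()
    {
        dispatch(true, false, 1024);
        return 0;
    }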
@ -1227,8 +1227,12 @@ template <
|
||||
typename T,
|
||||
// The type of the cache.
|
||||
typename Tcache,
|
||||
// The type of the shift key cache.
|
||||
typename TKcache,
|
||||
// Type of struct containing KV cache
|
||||
typename KVCacheBuffer,
|
||||
// Type of struct containing K cache to read past keys
|
||||
typename KCacheBuffer,
|
||||
// The hidden dimension per head.
|
||||
unsigned Dh,
|
||||
// The number of threads in a threadblock.
|
||||
@ -1239,6 +1243,8 @@ template <
|
||||
bool HAS_BEAMS,
|
||||
// Whether enable multi-block mode for long-sequence-length.
|
||||
bool DO_MULTI_BLOCK = false,
|
||||
// Whether enable position shift for streamingllm
|
||||
bool POS_SHIFT = false,
|
||||
// The number of threads per key.
|
||||
unsigned THREADS_PER_KEY = threads_per_key<T, dh_max(Dh)>(),
|
||||
// The number of threads per value.
|
||||
@ -1249,13 +1255,15 @@ template <
|
||||
// The unroll factor for loading from V cache.
|
||||
unsigned V_LOOP_UNROLL = 8>
|
||||
__global__ void masked_multihead_attention_kernel(
|
||||
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer)
|
||||
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer, KCacheBuffer pastKCache)
|
||||
{
|
||||
|
||||
using Tk = typename kernel_type_t<T>::Type;
|
||||
// Use 8bit cache.
|
||||
static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
|
||||
static constexpr bool ENABLE_8BITS_K_CACHE = sizeof(TKcache) == 1;
|
||||
static constexpr bool ENABLE_8BITS_KV_CACHE = sizeof(Tcache) == 1;
|
||||
// FP8 KV Cache.
|
||||
static constexpr bool FP8_K_CACHE = std::is_same<TKcache, __nv_fp8_e4m3>::value;
|
||||
static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
|
||||
// INT8 KV Cache.
|
||||
static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;
|
||||
@ -1276,6 +1284,8 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Note that max_attention_window_size is the maximum of cyclic_attention_window_size among all layers.
|
||||
// By default, you can assume that they are the same.
|
||||
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
|
||||
// The number of sink tokens in kv cache to support streamingllm
|
||||
const auto sink_token_len = static_cast<unsigned>(params.sink_token_length);
|
||||
// The current timestep (including paddings).
|
||||
// It is only used to calculate the smem stride.
|
||||
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
|
||||
@ -1335,7 +1345,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// The type of queries and keys for the math in the Q*K^T product.
|
||||
using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
|
||||
// Only used when key cache is quantized to 8 bits.
|
||||
using K_vec_m = typename packed_type<Tcache, num_elems<K_vec_k>::value>::type;
|
||||
using K_vec_m = typename packed_type<TKcache, num_elems<K_vec_k>::value>::type;
|
||||
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
|
||||
using K_vec_accum = typename Qk_vec_accum_fp32_<K_vec_k>::Type;
|
||||
#else
|
||||
@ -1424,7 +1434,12 @@ __global__ void masked_multihead_attention_kernel(
|
||||
: (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast<int>(timestep));
|
||||
// We will use cyclic kv cache when it exceeds the limit.
|
||||
// The length position for storing new key and value.
|
||||
const int cyclic_tlength = tlength % cyclic_kv_cache_len;
|
||||
const int cyclic_tlength = kvCacheBuffer.getKVTokenIdx(tlength);
|
||||
// When cyclic kv cache and the one-more-block mode are enabled, we need to shift the index to the actual index in the
|
||||
// sequence. Otherwise, if the token is not a sink token, we need to add the bubble length to the index.
|
||||
const bool enable_use_seq_idx_kv = kvCacheBuffer.mEnableOneMoreBlock && tlength > cyclic_kv_cache_len;
|
||||
const int shift_for_cyclic_kv = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : kvCacheBuffer.mBubbleLen;
|
||||
const int shift_for_cyclic_k = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : pastKCache.mBubbleLen;
|
||||
// The actual kv cache length.
|
||||
// tlength is the past length actually.
|
||||
const int kv_loop_length = min(tlength, cyclic_kv_cache_len);
|
||||
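The comments in this hunk describe how a position in the sequence is remapped before reading the K cache: sink tokens keep their index, other tokens are shifted by the bubble length, or, when the one-more-block mode kicks in, rebased past the overwritten prefix and then translated through getKVTokenIdx. The function below is only a rough host-side restatement of that rule under my reading of the code; the real KVBlockArray additionally remaps through getKVTokenIdx and per-block indexing:

    #include <cstdio>

    // Rough illustration, not the real KVBlockArray logic.
    int kCacheReadIdx(int seqPos, int sinkLen, int cyclicWindow, int pastLen, int bubbleLen, bool oneMoreBlock)
    {
        if (seqPos < sinkLen)
        {
            return seqPos; // sink tokens are never shifted
        }
        if (oneMoreBlock && pastLen > cyclicWindow)
        {
            // Rebase past the prefix that has already been overwritten; the kernel then
            // converts this sequence index to a cache index via getKVTokenIdx().
            return seqPos + (pastLen - cyclicWindow);
        }
        // Otherwise account for the bubble slots between the sink tokens and the cyclic part.
        return seqPos + bubbleLen;
    }

    int main()
    {
        std::printf("pos 10 -> cache idx %d\n", kCacheReadIdx(10, 2, 512, 40, 0, false));
        return 0;
    }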
@ -1433,6 +1448,8 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// as context kv cache might be overwritten by the new kv cache
|
||||
const int beam0_context_length
|
||||
= HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
|
||||
// The position of the current timestep, and it is used to apply the position embedding
|
||||
const int current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
|
||||
|
||||
// The offset in the Q and K buffer also accounts for the batch.
|
||||
const auto qk_vec_idx = tidx * QK_VEC_SIZE;
|
||||
@ -1443,25 +1460,29 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
// Quant/Dequant scales for 8bits kv cache.
|
||||
using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
|
||||
T_scale kv_scale_orig_quant, kv_scale_quant_orig;
|
||||
const float kv_scale_quant_orig_f = (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
|
||||
convert_from_float(&kv_scale_quant_orig, kv_scale_quant_orig_f);
|
||||
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
|
||||
T_scale kv_scale_orig_quant, k_scale_quant_orig;
|
||||
const float k_scale_quant_orig_f = (ENABLE_8BITS_K_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
|
||||
const float kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
|
||||
convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f);
|
||||
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
|
||||
|
||||
// Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
|
||||
// Trigger the loads from the Q and K buffers.
|
||||
Qk_vec_k q, k, q_bias, k_bias;
|
||||
// key without position embedding
|
||||
Qk_vec_k k_wo_pos;
|
||||
zero(q);
|
||||
zero(k);
|
||||
zero(q_bias);
|
||||
zero(k_bias);
|
||||
zero(k_wo_pos);
|
||||
float rotary_embedding_base = params.rotary_embedding_base;
|
||||
float rotary_embedding_scale = params.rotary_embedding_scale;
|
||||
if (is_valid_qk_vec)
|
||||
{
|
||||
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
|
||||
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
|
||||
tlength);
|
||||
current_pos_idx);
|
||||
// Query
|
||||
// The stride between tokens. We may be able to always use params.stride.
|
||||
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
|
||||
@ -1553,6 +1574,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
|
||||
¶ms.ia3_key_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, qk_vec_idx, Dh)])));
|
||||
}
|
||||
k_wo_pos = k;
|
||||
|
||||
// Note we have no paddings in KV cache now.
|
||||
switch (params.position_embedding_type)
|
||||
@ -1569,12 +1591,12 @@ __global__ void masked_multihead_attention_kernel(
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
|
||||
params.rotary_embedding_scale, tlength);
|
||||
params.rotary_embedding_scale, current_pos_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
|
||||
params.rotary_embedding_scale, tlength);
|
||||
params.rotary_embedding_scale, current_pos_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -1613,14 +1635,14 @@ __global__ void masked_multihead_attention_kernel(
|
||||
mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, tlength);
|
||||
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, tlength);
|
||||
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
}
|
||||
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
||||
}
|
||||
@ -1648,7 +1670,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
|
||||
// Store the Q values to shared memory.
|
||||
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
||||
if constexpr (FP8_KV_CACHE)
|
||||
if constexpr (FP8_K_CACHE)
|
||||
{
|
||||
// There are many more elements from K than elements from Q so we pre-scale Q instead
|
||||
// of scaling all the elements from K. It helps reduce the number of ops.
|
||||
@ -1656,7 +1678,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
zero(scaled_q);
|
||||
if (is_valid_qk_vec)
|
||||
{
|
||||
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(kv_scale_quant_orig, q);
|
||||
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(k_scale_quant_orig, q);
|
||||
}
|
||||
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
|
||||
}
|
||||
@ -1672,7 +1694,14 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Store the K values to shared memory.
|
||||
// We store K values from shared memory to global memory
|
||||
// when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache)
|
||||
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
|
||||
if (POS_SHIFT && !DO_CROSS_ATTENTION)
|
||||
{
|
||||
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k_wo_pos;
|
||||
}
|
||||
else
|
||||
{
|
||||
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
|
||||
}
|
||||
|
||||
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
|
||||
qk = dot<Qk_vec_accum, Qk_vec_k>(q, k);
|
||||
@ -1766,9 +1795,6 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// The number of unrolled keys per iteration.
|
||||
constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL;
|
||||
|
||||
// Base pointer for the row of pointers to k cache blocks
|
||||
void** k_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::K_IDX, batch_beam_idx));
|
||||
|
||||
const auto timesteps_per_block = static_cast<unsigned>(params.timesteps_per_block);
|
||||
|
||||
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
|
||||
@ -1840,13 +1866,24 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Dh OOB values will be handled by zero_q.
|
||||
// Seq OOB values will be masked out when storing back to smem.
|
||||
auto const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
||||
const int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
|
||||
int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
|
||||
if (valid_time_now >= sink_token_len)
|
||||
{
|
||||
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
||||
// Otherwise, we need to add the bubble length to the index
|
||||
valid_time_now += shift_for_cyclic_k;
|
||||
if (enable_use_seq_idx_kv)
|
||||
{
|
||||
// Convert the token index in sequence to token index in K cache.
|
||||
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
|
||||
}
|
||||
}
|
||||
const int seqIdx = batch_idx * beam_width;
|
||||
|
||||
// Base pointer to k cache block for beam's batch
|
||||
Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
|
||||
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
|
||||
|
||||
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
||||
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
||||
k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]);
|
||||
}
|
||||
}
|
||||
@ -1904,16 +1941,16 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
|
||||
float qk_ = 0.f;
|
||||
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
||||
if constexpr (FP8_KV_CACHE)
|
||||
if constexpr (FP8_K_CACHE)
|
||||
{
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
||||
}
|
||||
else
|
||||
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
||||
{
|
||||
if constexpr (ENABLE_8BITS_CACHE)
|
||||
if constexpr (ENABLE_8BITS_K_CACHE)
|
||||
{
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
|
||||
* params.inv_sqrt_dh;
|
||||
}
|
||||
else
|
||||
@ -1975,13 +2012,24 @@ __global__ void masked_multihead_attention_kernel(
|
||||
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
|
||||
{
|
||||
const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
|
||||
const int valid_time_now = min(time_now, kv_loop_length - 1);
|
||||
int valid_time_now = min(time_now, kv_loop_length - 1);
|
||||
int beam_offset = beam_indices[valid_time_now];
|
||||
if (valid_time_now >= sink_token_len)
|
||||
{
|
||||
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
||||
// Otherwise, we need to add the bubble length to the index
|
||||
valid_time_now += shift_for_cyclic_k;
|
||||
if (enable_use_seq_idx_kv)
|
||||
{
|
||||
// Convert the token index in sequence to token index in K cache.
|
||||
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
|
||||
}
|
||||
}
|
||||
const int seqIdx = batch_idx * beam_width + beam_offset;
|
||||
// Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
|
||||
Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
|
||||
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
|
||||
|
||||
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
||||
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
|
||||
k_vec[k_vec_i] = (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]));
|
||||
}
|
||||
|
||||
@ -2024,16 +2072,16 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
|
||||
float qk_ = 0.f;
|
||||
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
||||
if constexpr (FP8_KV_CACHE)
|
||||
if constexpr (FP8_K_CACHE)
|
||||
{
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
|
||||
}
|
||||
else
|
||||
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
|
||||
{
|
||||
if constexpr (ENABLE_8BITS_CACHE)
|
||||
if constexpr (ENABLE_8BITS_K_CACHE)
|
||||
{
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
|
||||
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
|
||||
* params.inv_sqrt_dh;
|
||||
}
|
||||
else
|
||||
@ -2126,7 +2174,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
|
||||
|
||||
if constexpr (ENABLE_8BITS_CACHE)
|
||||
if constexpr (ENABLE_8BITS_KV_CACHE)
|
||||
{
|
||||
store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k_vec, inBlockIdx, kv_scale_orig_quant);
|
||||
}
|
||||
@ -2224,12 +2272,6 @@ __global__ void masked_multihead_attention_kernel(
|
||||
const auto vo = v_idx.x;
|
||||
// The hidden dimensions computed by this particular thread.
|
||||
const auto vi = v_idx.y;
|
||||
// Base pointer for the row of pointers to v cache blocks
|
||||
void** v_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, batch_beam_idx));
|
||||
// Base pointer for the row of pointers to v cache blocks for beam's batch, before offsetting with indirection
|
||||
// buffer
|
||||
void** v_cache_batch_row_ptr
|
||||
= reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, batch_idx * beam_width));
|
||||
|
||||
// The number of values processed per iteration of the loop.
|
||||
constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE};
|
||||
@ -2293,6 +2335,17 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// Fetch offset based on cache_indir when beam sampling
|
||||
int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
|
||||
time_idx = min(time_idx, kv_loop_length - 1);
|
||||
if (time_idx >= sink_token_len)
|
||||
{
|
||||
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
||||
// Otherwise, we need to add the bubble length to the index
|
||||
time_idx += shift_for_cyclic_kv;
|
||||
if (enable_use_seq_idx_kv)
|
||||
{
|
||||
// Convert the token index in sequence to token index in V cache.
|
||||
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
|
||||
}
|
||||
}
|
||||
int rowIdx = batch_idx * beam_width;
|
||||
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
|
||||
@ -2339,6 +2392,18 @@ __global__ void masked_multihead_attention_kernel(
|
||||
}
|
||||
int rowIdx = batch_idx * beam_width + beam_indices[time_idx];
|
||||
|
||||
if (time_idx >= sink_token_len)
|
||||
{
|
||||
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
|
||||
// Otherwise, we need to add the bubble length to the index
|
||||
time_idx += shift_for_cyclic_kv;
|
||||
if (enable_use_seq_idx_kv)
|
||||
{
|
||||
// Convert the token index in sequence to token index in V cache.
|
||||
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
|
||||
}
|
||||
}
|
||||
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));
|
||||
@ -2362,10 +2427,9 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// One group of threads computes the product(s) for the current timestep.
|
||||
if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
|
||||
{
|
||||
const int tokenIdx = cyclic_tlength;
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
|
||||
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, vi);
|
||||
// The base pointer for the value in the cache buffer.
|
||||
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getBlockPtr(v_cache_base_row_ptr, tokenIdx));
|
||||
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(batch_beam_idx, cyclic_tlength));
|
||||
|
||||
V_vec_k v;
|
||||
if (DO_CROSS_ATTENTION)
|
||||
@ -2414,7 +2478,7 @@ __global__ void masked_multihead_attention_kernel(
|
||||
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
|
||||
if (hi == (hi_kv * qhead_per_kv))
|
||||
{
|
||||
if (ENABLE_8BITS_CACHE)
|
||||
if (ENABLE_8BITS_KV_CACHE)
|
||||
{
|
||||
store_8bits_kv_cache_vec(v_cache_base, v, inBlockIdx, kv_scale_orig_quant);
|
||||
}
|
||||
|
||||
@ -287,8 +287,8 @@ public:
|
||||
}
|
||||
|
||||
template <typename T, bool HAS_BEAM>
|
||||
void run(const XQAParams& xqaParams, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,
|
||||
int multiprocessor_count, int max_multi_block_slots) const
|
||||
void run(const XQAParams& xqaParams, KVLinearBuffer& kv_linear_buffer, int2& rotary_kernel_launch_cache,
|
||||
const cudaStream_t& stream, int multiprocessor_count, int max_multi_block_slots) const
|
||||
{
|
||||
unsigned int head_size = xqaParams.head_size;
|
||||
int num_q_heads = xqaParams.num_q_heads;
|
||||
@ -303,11 +303,12 @@ public:
|
||||
|
||||
invokeApplyBiasRopeUpdateKVCache<T, KVLinearBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
|
||||
nullptr, kv_linear_buffer, static_cast<const T*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr,
|
||||
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.batch_size * beam_width,
|
||||
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
|
||||
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
|
||||
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, (float*) nullptr, 0,
|
||||
cache_type, xqaParams.kv_scale_orig_quant, false, stream, beam_width);
|
||||
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
xqaParams.batch_size * beam_width, xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size,
|
||||
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
|
||||
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions,
|
||||
xqaParams.position_embedding_type, xqaParams.position_shift_enabled, (float*) nullptr, 0, cache_type,
|
||||
xqaParams.kv_scale_orig_quant, false, beam_width, rotary_kernel_launch_cache, stream);
|
||||
|
||||
XQAKernelRuntimeHashKey hash_key{xqaParams.kv_cache_data_type, head_size, num_q_heads_over_kv, beam_width};
|
||||
const auto findIter = mFunctions.find(hash_key);
|
||||
@ -459,8 +460,8 @@ public:
|
||||
return xqaKernel->supportConfig(xqaParams) && xqaKernel->mayHavePerfGain(xqaParams, mMultiProcessorCount);
|
||||
}
|
||||
|
||||
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,
|
||||
int max_multi_block_slots);
|
||||
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, int2& rotary_kernel_launch_cache,
|
||||
const cudaStream_t& stream, int max_multi_block_slots);
|
||||
|
||||
private:
|
||||
const XQAKernelList* xqaKernel;
|
||||
@ -470,17 +471,17 @@ private:
|
||||
};
|
||||
|
||||
void DecoderXQARunner::xqaImpl::run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer,
|
||||
const cudaStream_t& stream, int max_multi_block_slots)
|
||||
int2& rotary_kernel_launch_cache, const cudaStream_t& stream, int max_multi_block_slots)
|
||||
{
|
||||
if (xqa_params.beam_width > 1)
|
||||
{
|
||||
xqaKernel->template run<__half, true>(
|
||||
xqa_params, kv_linear_buffer, stream, mMultiProcessorCount, max_multi_block_slots);
|
||||
xqaKernel->template run<__half, true>(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream,
|
||||
mMultiProcessorCount, max_multi_block_slots);
|
||||
}
|
||||
else
|
||||
{
|
||||
xqaKernel->template run<__half, false>(
|
||||
xqa_params, kv_linear_buffer, stream, mMultiProcessorCount, max_multi_block_slots);
|
||||
xqaKernel->template run<__half, false>(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream,
|
||||
mMultiProcessorCount, max_multi_block_slots);
|
||||
}
|
||||
}
|
||||
|
||||
@ -545,7 +546,7 @@ void DecoderXQARunner::run(const XQAParams& xqa_params, KVLinearBuffer& kv_linea
|
||||
{
|
||||
int max_multi_block_slots = kMaxBeamWidth * XQALaunchParam<true>::GetMaxBatchSizePerWave(mMultiProcessorCount)
|
||||
* kMaxNbCtaPerKVHeadFactor * mNumKVHeads;
|
||||
return pimpl->run(xqa_params, kv_linear_buffer, stream, max_multi_block_slots);
|
||||
return pimpl->run(xqa_params, kv_linear_buffer, mLaunchGridBlockCache, stream, max_multi_block_slots);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -63,6 +63,7 @@ struct XQAParams
|
||||
int32_t beam_width = 0;
|
||||
int32_t max_attention_window_size = 0;
|
||||
int32_t cyclic_attention_window_size = 0;
|
||||
int32_t sink_token_length = 0;
|
||||
int timestep = 0;
|
||||
const void* qkv_bias;
|
||||
const int32_t* sequence_lengths; //
|
||||
@ -82,6 +83,7 @@ struct XQAParams
|
||||
float rotary_embedding_scale;
|
||||
int rotary_embedding_max_positions;
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type;
|
||||
bool position_shift_enabled = false;
|
||||
bool remove_padding = false;
|
||||
tensorrt_llm::kernels::AttentionMaskType mask_type;
|
||||
bool paged_kv_cache;
|
||||
@ -144,6 +146,8 @@ public:
|
||||
SUPPORT_RETURN_FALSE("beam_width");
|
||||
if (xqaParams.cyclic_attention_window_size != xqaParams.max_attention_window_size)
|
||||
SUPPORT_RETURN_FALSE("cyclic_attention_window_size != max_attention_window_size");
|
||||
if (xqaParams.position_shift_enabled || xqaParams.sink_token_length > 0)
|
||||
SUPPORT_RETURN_FALSE("streaming-llm");
|
||||
return shouldUseImpl(xqaParams);
|
||||
}
|
||||
|
||||
@ -182,6 +186,10 @@ private:
|
||||
|
||||
static constexpr int kMaxBeamWidth = 4;
|
||||
|
||||
// Cache the grid_size and block_size that gives the highest occupancy for
|
||||
// invokeApplyBiasRopeUpdateKVCache.
|
||||
int2 mLaunchGridBlockCache = make_int2(0, 0);
|
||||
|
||||
class xqaImpl;
|
||||
std::unique_ptr<xqaImpl> pimpl;
|
||||
|
||||
|
||||
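mLaunchGridBlockCache stores the grid/block pair once so the occupancy query for the bias/RoPE KV-update kernel does not have to be repeated on every call. A hedged CUDA sketch of that caching idea, using the standard occupancy helper and a dummy kernel (names are illustrative, not the TensorRT-LLM implementation):

    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void dummyRopeKernel(float* data, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) { data[i] *= 2.0f; }
    }

    // Query the occupancy-optimal block size once and cache it, similar in spirit to
    // mLaunchGridBlockCache above (int2{gridSize, blockSize}, 0 meaning "not computed yet").
    void launchCached(float* data, int n, int2& cache)
    {
        if (cache.y == 0)
        {
            int minGrid = 0, blockSize = 0;
            cudaOccupancyMaxPotentialBlockSize(&minGrid, &blockSize, dummyRopeKernel, 0, 0);
            cache = make_int2(minGrid, blockSize);
        }
        int grid = (n + cache.y - 1) / cache.y;
        dummyRopeKernel<<<grid, cache.y>>>(data, n);
    }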
@ -28,6 +28,7 @@ using tensorrt_llm::common::bf16hmul2;
|
||||
using tensorrt_llm::common::bf16hmul;
|
||||
using tensorrt_llm::common::bf16hadd2;
|
||||
using tensorrt_llm::common::float22bf162;
|
||||
using tensorrt_llm::common::hsub2;
|
||||
#endif
|
||||
|
||||
namespace tensorrt_llm
|
||||
@ -48,7 +49,7 @@ struct __align__(16) Float4_
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct __align__(32) Float8_
|
||||
struct __align__(16) Float8_
|
||||
{
|
||||
float2 x;
|
||||
float2 y;
|
||||
@ -336,6 +337,32 @@ struct packed_type<float, 8>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ uint32_t sub(uint32_t a, uint32_t b)
|
||||
{
|
||||
uint32_t c;
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
return c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ __nv_bfloat162 sub(__nv_bfloat162 a, __nv_bfloat162 b)
|
||||
{
|
||||
return hsub2(a, b);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float2 sub(float2 a, float2 b)
|
||||
{
|
||||
float2 c;
|
||||
c.x = a.x - b.x;
|
||||
c.y = a.y - b.y;
|
||||
return c;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float add(float a, float b)
|
||||
{
|
||||
return a + b;
|
||||
@ -2510,7 +2537,19 @@ inline __device__ float update_rotary_base(
|
||||
{
|
||||
const float b = (scale * kv_seq_len / max_positions) - (scale - 1);
|
||||
const float p = static_cast<float>(embed_dim) / (embed_dim - 2);
|
||||
return base * pow(b, p);
|
||||
return base * __powf(b, p);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float2 update_dynamic_scaling_rotary(float base, float scale, const int kv_seq_len,
|
||||
const int max_positions, const int embed_dim, const bool dynamic_scaling)
|
||||
{
|
||||
const float b = kv_seq_len * __fdividef(scale, max_positions) - (scale - 1);
|
||||
const float p = __fdividef(embed_dim, embed_dim - 2);
|
||||
const float updated_base = dynamic_scaling ? base * __powf(b, p) : base;
|
||||
const float updated_scale = dynamic_scaling ? 1.0f : scale;
|
||||
return {updated_base, updated_scale};
|
||||
}
|
||||
|
||||
inline __device__ void update_rotary_base_n_scale(float& base, float& scale, RotaryScalingType const scale_type,
|
||||
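Written out, the new update_dynamic_scaling_rotary helper above applies the dynamic (NTK-aware) rotary scaling rule. With s the configured rotary scale, L_kv the current KV sequence length, L_max = rotary_embedding_max_positions and d the rotary embedding dimension, the updated base and scale are:

    b = \frac{s \, L_{kv}}{L_{max}} - (s - 1), \qquad p = \frac{d}{d - 2},
    \qquad
    \mathrm{base}' = \begin{cases} \mathrm{base} \cdot b^{\,p} & \text{if dynamic scaling} \\ \mathrm{base} & \text{otherwise,} \end{cases}
    \qquad
    s' = \begin{cases} 1 & \text{if dynamic scaling} \\ s & \text{otherwise.} \end{cases}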
@ -2534,8 +2573,8 @@ inline __device__ void update_rotary_base_n_scale(float& base, float& scale, Rot
inline __device__ float2 rotary_embedding_coefficient(
    const int zid, const int rot_embed_dim, const float base, const float scale, const float t_step)
{
    const float inv_freq = (t_step * scale) / pow(base, zid / (float) rot_embed_dim);
    return {cos(inv_freq), sin(inv_freq)};
    const float inv_freq = __fdividef(float(t_step * scale), __powf(base, zid / (float) rot_embed_dim));
    return {__cosf(inv_freq), __sinf(inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
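The change above swaps the precise pow/cos/sin calls for their fast-math device intrinsics; a host-side reference of the same coefficient, for comparison (a sketch, not part of the kernel code):

#include <cmath>
#include <cuda_runtime.h>

// Reference version: the coefficient is the (cos, sin) pair of the standard
// RoPE angle theta = t_step * scale / base^(zid / rot_embed_dim).
float2 rotary_embedding_coefficient_ref(
    int zid, int rot_embed_dim, float base, float scale, float t_step)
{
    const float inv_freq = (t_step * scale) / std::pow(base, zid / (float) rot_embed_dim);
    return make_float2(std::cos(inv_freq), std::sin(inv_freq));
}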
@ -2629,6 +2668,30 @@ inline __device__ void apply_rotary_embedding(
    k_.y = rotary_embedding_transform(k_.y, coef1);
}

inline __device__ void apply_rotary_embedding(
    Float8_& q, Float8_& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
{
    if (8 * tid >= rot_embed_dim)
    {
        return;
    }

    Float8_& q_ = *reinterpret_cast<Float8_*>(&q);
    Float8_& k_ = *reinterpret_cast<Float8_*>(&k);
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    k_.x = rotary_embedding_transform(k_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
    q_.y = rotary_embedding_transform(q_.y, coef1);
    k_.y = rotary_embedding_transform(k_.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
    q_.z = rotary_embedding_transform(q_.z, coef2);
    k_.z = rotary_embedding_transform(k_.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
    q_.w = rotary_embedding_transform(q_.w, coef3);
    k_.w = rotary_embedding_transform(k_.w, coef3);
}

inline __device__ void apply_rotary_embedding(
    uint32_t& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
{
@ -2821,6 +2884,205 @@ inline __device__ void apply_rotary_embedding(

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(uint16_t& q, uint16_t q_pair, uint16_t& k, uint16_t k_pair, int tid0,
    int tid1, // not used
    int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    uint32_t cos = float2_to_half2(make_float2(coef.x, coef.x));
    uint32_t sin = float2_to_half2(make_float2(coef.y, coef.y));
    uint32_t h2, h2_pair;
    reinterpret_cast<uint16_t*>(&h2)[0] = q;
    reinterpret_cast<uint16_t*>(&h2)[1] = k;
    reinterpret_cast<uint16_t*>(&h2_pair)[0] = q_pair;
    reinterpret_cast<uint16_t*>(&h2_pair)[1] = k_pair;
    if (first_half)
    {
        h2 = sub(mul<uint32_t>(cos, h2), mul<uint32_t>(sin, h2_pair));
    }
    else
    {
        h2 = add(mul<uint32_t>(cos, h2), mul<uint32_t>(sin, h2_pair));
    }
    q = reinterpret_cast<uint16_t*>(&h2)[0];
    k = reinterpret_cast<uint16_t*>(&h2)[1];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t q_pair, uint32_t& k, uint32_t k_pair, int tid0,
    int tid1, int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
    uint32_t cos0 = float2_to_half2(make_float2(coef0.x, coef1.x));
    uint32_t sin0 = float2_to_half2(make_float2(coef0.y, coef1.y));
    if (first_half)
    {
        q = sub(mul<uint32_t>(cos0, q), mul<uint32_t>(sin0, q_pair));
        k = sub(mul<uint32_t>(cos0, k), mul<uint32_t>(sin0, k_pair));
    }
    else
    {
        q = add(mul<uint32_t>(cos0, q), mul<uint32_t>(sin0, q_pair));
        k = add(mul<uint32_t>(cos0, k), mul<uint32_t>(sin0, k_pair));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(__nv_bfloat16& q, __nv_bfloat16 q_pair, __nv_bfloat16& k,
    __nv_bfloat16 k_pair, int tid0,
    int tid1, // not used
    int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    __nv_bfloat162 cos = float22bf162(make_float2(coef.x, coef.x));
    __nv_bfloat162 sin = float22bf162(make_float2(coef.y, coef.y));
    __nv_bfloat162 h2, h2_pair;
    reinterpret_cast<__nv_bfloat16*>(&h2)[0] = q;
    reinterpret_cast<__nv_bfloat16*>(&h2)[1] = k;
    reinterpret_cast<__nv_bfloat16*>(&h2_pair)[0] = q_pair;
    reinterpret_cast<__nv_bfloat16*>(&h2_pair)[1] = k_pair;
    if (first_half)
    {
        h2 = sub(mul<__nv_bfloat162>(cos, h2), mul<__nv_bfloat162>(sin, h2_pair));
    }
    else
    {
        h2 = add(mul<__nv_bfloat162>(cos, h2), mul<__nv_bfloat162>(sin, h2_pair));
    }
    q = reinterpret_cast<__nv_bfloat16*>(&h2)[0];
    k = reinterpret_cast<__nv_bfloat16*>(&h2)[1];
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162 q_pair, __nv_bfloat162& k,
    __nv_bfloat162 k_pair, int tid0, int tid1, int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
    __nv_bfloat162 cos0 = float22bf162(make_float2(coef0.x, coef1.x));
    __nv_bfloat162 sin0 = float22bf162(make_float2(coef0.y, coef1.y));
    if (first_half)
    {
        q = sub(mul<__nv_bfloat162>(cos0, q), mul<__nv_bfloat162>(sin0, q_pair));
        k = sub(mul<__nv_bfloat162>(cos0, k), mul<__nv_bfloat162>(sin0, k_pair));
    }
    else
    {
        q = add(mul<__nv_bfloat162>(cos0, q), mul<__nv_bfloat162>(sin0, q_pair));
        k = add(mul<__nv_bfloat162>(cos0, k), mul<__nv_bfloat162>(sin0, k_pair));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(float& q, float q_pair, float& k, float k_pair, int tid0,
    int tid1, // not used
    int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    float cos = coef.x;
    float sin = coef.y;
    if (first_half)
    {
        q = sub(mul<float>(cos, q), mul<float>(sin, q_pair));
        k = sub(mul<float>(cos, k), mul<float>(sin, k_pair));
    }
    else
    {
        q = add(mul<float>(cos, q), mul<float>(sin, q_pair));
        k = add(mul<float>(cos, k), mul<float>(sin, k_pair));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void apply_rotary_embedding(float2& q, float2 q_pair, float2& k, float2 k_pair, int tid0, int tid1,
    int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
    const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
    const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
    float2 cos0 = make_float2(coef0.x, coef1.x);
    float2 sin0 = make_float2(coef0.y, coef1.y);
    if (first_half)
    {
        q = sub(mul<float2>(cos0, q), mul<float2>(sin0, q_pair));
        k = sub(mul<float2>(cos0, k), mul<float2>(sin0, k_pair));
    }
    else
    {
        q = add(mul<float2>(cos0, q), mul<float2>(sin0, q_pair));
        k = add(mul<float2>(cos0, k), mul<float2>(sin0, k_pair));
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Vec_type, typename Packed_type, typename T>
inline __device__ void apply_rotary_embedding_gptneox(Vec_type& q, Vec_type& k, int tidx, int rotary_embedding_dim,
    float rotary_embedding_base, float rotary_embedding_scale, int t_step, bool first_half)
{
    // 32 threads: each holds VEC_SIZE elements (half)
    Vec_type q_pair, k_pair;
    constexpr int VEC_SIZE = sizeof(Vec_type) / sizeof(Packed_type);
    constexpr int PACKED_ELT_SIZE = sizeof(Packed_type) / sizeof(T);
    if constexpr (sizeof(Vec_type) == 2)
    {
        reinterpret_cast<uint16_t&>(q_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<uint16_t&>(q), 16);
        reinterpret_cast<uint16_t&>(k_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<uint16_t&>(k), 16);
    }
    else if constexpr (sizeof(Vec_type) == 4)
    {
        reinterpret_cast<unsigned int&>(q_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned int&>(q), 16);
        reinterpret_cast<unsigned int&>(k_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned int&>(k), 16);
    }
    else if constexpr (sizeof(Vec_type) >= 8)
    {
#pragma unroll
        for (int vec_id = 0; vec_id < sizeof(Vec_type) / 8; vec_id++)
        {
            reinterpret_cast<unsigned long*>(&q_pair)[vec_id]
                = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned long*>(&q)[vec_id], 16);
            reinterpret_cast<unsigned long*>(&k_pair)[vec_id]
                = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned long*>(&k)[vec_id], 16);
        }
    }

    const int half_rotary_dim = rotary_embedding_dim / 2;

#pragma unroll
    for (int elt_id = 0; elt_id < VEC_SIZE; elt_id++)
    {
        // Pack two elements for calculation (only one if each thread gets a single element).
        // Assume the head size (or rotary embedding dimension) is a multiple of 8.
        const int rotary_emd_pos0_id
            = (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 0 - int(!first_half) * half_rotary_dim)
            * 2;
        const int rotary_emd_pos1_id
            = (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 1 - int(!first_half) * half_rotary_dim)
            * 2;

        const bool valid_rotary_pos = rotary_emd_pos1_id < rotary_embedding_dim;

        Packed_type q_ = reinterpret_cast<Packed_type*>(&q)[elt_id];
        Packed_type q_pair_ = reinterpret_cast<Packed_type*>(&q_pair)[elt_id];
        Packed_type k_ = reinterpret_cast<Packed_type*>(&k)[elt_id];
        Packed_type k_pair_ = reinterpret_cast<Packed_type*>(&k_pair)[elt_id];

        apply_rotary_embedding(q_, q_pair_, k_, k_pair_, rotary_emd_pos0_id, rotary_emd_pos1_id, rotary_embedding_dim,
            rotary_embedding_base, rotary_embedding_scale, t_step, first_half);

        if (valid_rotary_pos)
        {
            reinterpret_cast<Packed_type*>(&q)[elt_id] = q_;
            reinterpret_cast<Packed_type*>(&k)[elt_id] = k_;
        }
    }
}
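A scalar sketch of the GPT-NeoX (half-rotation) scheme implemented above, under the assumption that element i of the first half is rotated together with element i + rot_dim/2; in the kernel those partner elements sit in another lane, which is why the vectors are first exchanged with __shfl_xor_sync(0xffffffff, ..., 16):

#include <cmath>

// Reference (single-threaded) half rotation over one head: both elements of a
// pair share the angle theta_i = t_step * scale / base^(2*i / rot_dim).
void apply_rotary_embedding_gptneox_ref(float* v, int rot_dim, float base, float scale, int t_step)
{
    for (int i = 0; i < rot_dim / 2; ++i)
    {
        const float theta = (t_step * scale) / std::pow(base, (2.0f * i) / rot_dim);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x = v[i];
        const float y = v[i + rot_dim / 2];
        v[i] = c * x - s * y;               // first half:  cos * x - sin * pair
        v[i + rot_dim / 2] = c * y + s * x; // second half: cos * pair + sin * x
    }
}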
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void convert_from_float(float* dst, float src)
@ -1,5 +1,5 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -1,5 +1,5 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -204,7 +204,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
        case AttentionMaskType::BIDIRECTIONALGLM:
            // clang-format off
            isValid = (colIdx < seqLength - 1) ||
                      (rowIdx == maxSeqLength - 1 && colIdx == maxSeqLength - 1);
                      (rowIdx == seqLength - 1 && colIdx == seqLength - 1);
            // clang-format on
            // seq_length==4, max_seq_len==5
            // 1 1 1 1 0
@ -1,5 +1,5 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@ -83,6 +83,8 @@ struct BuildDecoderInfoParams
    // The kv cache capacity.
    // We will apply the limited_length_causal mask when there is not enough kv cache.
    int attentionWindowSize;
    // The number of sink tokens in the kv cache.
    int sinkTokenLength;
    // The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii].
    int numTokens;
    // The type of attention.
Some files were not shown because too many files have changed in this diff