Update TensorRT-LLM (#787)

* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2024-01-02 17:54:32 +08:00 committed by GitHub
parent d37b507f41
commit deaae40bd7
526 changed files with 9173 additions and 4375 deletions

View File

@ -45,4 +45,5 @@ repos:
args:
- --skip=".git,3rdparty"
- --exclude-file=examples/whisper/tokenizer.py
- --ignore-words-list=rouge,inout,atleast,strat
- --ignore-words-list=rouge,inout,atleast,strat,nd
exclude: 'tests/llm-test-defs/turtle/test_input_files'

View File

@ -108,9 +108,7 @@ concepts used in TensorRT-LLM, we recommend you to read the following
## Installation
*For Windows installation, see [`Windows`](windows/README.md).*
TensorRT-LLM must be built from source, instructions can be found
The documentation for installing TensorRT-LLM can be found
[here](./docs/source/installation.md). An image of a Docker container with
TensorRT-LLM and its Triton Inference Server Backend will be made available
soon.
@ -118,6 +116,8 @@ soon.
The remaining commands in that document must be executed from the TensorRT-LLM
container.
*For Windows installation, see [`Windows`](windows/README.md).*
## Quick Start
To create a TensorRT engine for an existing model, there are 3 steps:
@ -206,13 +206,17 @@ Lovelace architectures. Certain limitations may, however, apply.
Various numerical precisions are supported in TensorRT-LLM. The support for
some of those numerical features require specific architectures:
| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
| :------------------ | :--- | :--- | :--- | :--- | :--- | :--- |
| Volta (SM70) | Y | Y | N | N | Y | Y |
| Turing (SM75) | Y | Y | N | N | Y | Y |
| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y |
| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
| Hopper (SM90) | Y | Y | Y | Y | Y | Y |
| | FP32 | FP16 | BF16 | FP8 | INT8 | INT4 |
| :------------------ | :--- | :--- | :--- | :--- | :---- | :---- |
| Volta (SM70) | Y | Y | N | N | Y (1) | Y (2) |
| Turing (SM75) | Y | Y | N | N | Y (1) | Y (2) |
| Ampere (SM80, SM86) | Y | Y | Y | N | Y | Y (3) |
| Ada-Lovelace (SM89) | Y | Y | Y | Y | Y | Y |
| Hopper (SM90) | Y | Y | Y | Y | Y | Y |
(1) INT8 SmoothQuant is not supported on SM70 and SM75.<br>
(2) INT4 AWQ and GPTQ are not supported on SM < 80.<br>
(3) INT4 AWQ and GPTQ with FP8 activations require SM >= 89.
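
For readers wiring these constraints into their own tooling, here is a minimal sketch of querying the SM version at runtime; the helper and thresholds simply mirror the table and notes above and are illustrative, not part of this change.

```cpp
#include <cuda_runtime_api.h>
#include <stdexcept>

// Illustrative helper (not part of this change): query the SM version of a
// device so precision-specific paths can be gated like the table above.
int getSmVersion(int deviceId)
{
    int major = 0;
    int minor = 0;
    if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, deviceId) != cudaSuccess
        || cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, deviceId) != cudaSuccess)
    {
        throw std::runtime_error("Failed to query compute capability");
    }
    return major * 10 + minor;
}

bool supportsFp8(int sm)             { return sm >= 89; } // Ada-Lovelace (SM89) and Hopper (SM90)
bool supportsInt8SmoothQuant(int sm) { return sm >= 80; } // note (1): not on SM70/SM75
bool supportsInt4AwqGptq(int sm)     { return sm >= 80; } // note (2): requires SM >= 80
```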
In this release of TensorRT-LLM, the support for FP8 and quantized data types
(INT8 or INT4) is not implemented for all the models. See the
@ -267,6 +271,7 @@ The list of supported models is:
* [MPT](examples/mpt)
* [mT5](examples/enc_dec)
* [OPT](examples/opt)
* [Phi-1.5/Phi-2](examples/phi)
* [Qwen](examples/qwen)
* [Replit Code](examples/mpt)
* [SantaCoder](examples/gpt)

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not

View File

@ -9,7 +9,7 @@ multiple GPUs or multiple nodes with multiple GPUs.
Please follow the [`installation document`](../../docs/source/installation.md) to build TensorRT-LLM.
Note that the benchmarking source code for C++ runtime is not built by default, you can use the argument `--benchmarks` in [`build_wheel.py`](../../scripts/build_wheel.py) to build that.
Note that the benchmarking source code for C++ runtime is not built by default, you can use the argument `--benchmarks` in [`build_wheel.py`](source:scripts/build_wheel.py) to build the corresponding executable.
Windows users: Follow the
[`Windows installation document`](../../windows/README.md)
@ -22,7 +22,7 @@ instead, and be sure to set DLL paths as specified in
Before you launch C++ benchmarking, please make sure that you have already built engine(s) using TensorRT-LLM API, C++ benchmarking code cannot generate engine(s) for you.
You can use the [`build.py`](../python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built by benchmarking code, please see that [`document`](../python/README.md).
You can use the [`build.py`](source:benchmarks/python/build.py) script to build the engine(s). Alternatively, if you have already benchmarked Python Runtime, you can reuse the engine(s) built previously, please see that [`document`](../python/README.md).
#### Launch benchmarking

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -267,13 +267,19 @@ public:
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId)
{
ReturnBatchManagerStatsCallback iterationDataCallback{nullptr};
if (optionalParams.logIterationData)
{
iterationDataCallback = [this](const std::string& s) { return TLLM_LOG_INFO(s); };
}
mBatchManager = std::make_shared<GptManager>(
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
[this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
const std::string& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, nullptr, optionalParams, terminateReqId);
nullptr, iterationDataCallback, optionalParams, terminateReqId);
mRecorder = recorder;
mTerminateReqId = terminateReqId;
}
@ -459,8 +465,14 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
request->setMaxNewTokens(
bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
request->setBeamWidth(beamWidthTensor);
request->setEndId(eosId);
request->setPadId(padId);
if (eosId != nullptr)
{
request->setEndId(eosId);
}
if (padId != nullptr)
{
request->setPadId(padId);
}
return request;
}
@ -571,6 +583,8 @@ int main(int argc, char* argv[])
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
options.add_options()(
"log_iteration_data", "On each decoder iteration, print batch state metadata.", cxxopts::value<bool>());
auto result = options.parse(argc, argv);
@ -618,6 +632,11 @@ int main(int argc, char* argv[])
{
optionalParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
}
// Argument: Enable batch stats output
if (result.count("log_iteration_data"))
{
optionalParams.logIterationData = result["log_iteration_data"].as<bool>();
}
std::optional<int32_t> padId;
// Argument: Padding token id

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -253,6 +253,7 @@ int main(int argc, char* argv[])
options.add_options()("gen_micro_batch_size", "Batch size for generation phase.", cxxopts::value<int>());
options.add_options()("max_attention_window", "Max kv cache length per sequence.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()("sink_token_len", "Sink token length in kv cache per sequence.", cxxopts::value<int>());
options.add_options()(
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@ -353,6 +354,11 @@ int main(int argc, char* argv[])
{
sessionConfig.kvCacheConfig.maxAttentionWindow = result["max_attention_window"].as<int>();
}
// Argument: Sink token length
if (result.count("sink_token_len"))
{
sessionConfig.kvCacheConfig.sinkTokenLength = result["sink_token_len"].as<int>();
}
// Argument: K-V Cache Free Gpu Mem Fraction
if (result.count("kv_cache_free_gpu_mem_fraction"))
{

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -231,7 +231,6 @@ _allowed_configs = {
builder_opt=None,
pre_norm=False,
do_layer_norm_before=False,
use_custom_all_reduce=False,
)),
"opt_2.7b":
ModelConfig(name="opt_2.7b",
@ -250,7 +249,6 @@ _allowed_configs = {
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
use_custom_all_reduce=False,
)),
"opt_6.7b":
ModelConfig(name="opt_6.7b",
@ -269,7 +267,6 @@ _allowed_configs = {
builder_opt=None,
pre_norm=False,
do_layer_norm_before=True,
use_custom_all_reduce=False,
)),
"opt_66b":
ModelConfig(name="opt_66b",
@ -288,7 +285,6 @@ _allowed_configs = {
builder_opt=None,
pre_norm=True,
do_layer_norm_before=True,
use_custom_all_reduce=False,
)),
"llama_7b":
ModelConfig(name="llama_7b",
@ -515,7 +511,6 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
remove_input_padding=False,
use_custom_all_reduce=False,
)),
"bloom_560m":
ModelConfig(name="bloom_560m",
@ -532,7 +527,6 @@ _allowed_configs = {
max_input_len=1024,
max_output_len=1024,
builder_opt=None,
use_custom_all_reduce=False,
)),
"bloom_176b":
ModelConfig(name="bloom_176b",
@ -549,7 +543,6 @@ _allowed_configs = {
max_input_len=1024,
max_output_len=1024,
builder_opt=None,
use_custom_all_reduce=False,
)),
"bert_base":
ModelConfig(name="bert_base",
@ -596,7 +589,7 @@ _allowed_configs = {
num_heads=32,
hidden_size=2048,
vocab_size=50304,
hidden_act=None,
hidden_act='gelu',
n_positions=2048,
max_batch_size=256,
max_input_len=1024,
@ -617,7 +610,7 @@ _allowed_configs = {
num_kv_heads=1,
hidden_size=4544,
vocab_size=65024,
hidden_act=None,
hidden_act='gelu',
n_positions=2048,
max_batch_size=128,
max_input_len=512,
@ -638,7 +631,7 @@ _allowed_configs = {
num_kv_heads=8,
hidden_size=8192,
vocab_size=65024,
hidden_act=None,
hidden_act='gelu',
n_positions=2048,
max_batch_size=64,
max_input_len=512,
@ -659,7 +652,7 @@ _allowed_configs = {
num_kv_heads=8,
hidden_size=14848,
vocab_size=65024,
hidden_act=None,
hidden_act='gelu',
n_positions=2048,
max_batch_size=8,
max_input_len=1024,
@ -921,6 +914,112 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"baichuan_7b":
ModelConfig(name="baichuan_7b",
family="baichuan_7b",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=64000,
hidden_act='silu',
n_positions=4096,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"baichuan2_7b_chat":
ModelConfig(name="baichuan2_7b_chat",
family="baichuan_7b",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=125696,
hidden_act='silu',
n_positions=4096,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"baichuan_13b_chat":
ModelConfig(name="baichuan_13b_chat",
family="baichuan_13b",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,
num_heads=40,
hidden_size=5120,
vocab_size=64000,
hidden_act='silu',
n_positions=4096,
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"baichuan2_13b_chat":
ModelConfig(name="baichuan2_13b_chat",
family="baichuan_13b",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,
num_heads=40,
hidden_size=5120,
vocab_size=125696,
hidden_act='silu',
n_positions=4096,
inter_size=13696,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"internlm_chat_7b":
ModelConfig(name="internlm_chat_7b",
family="internlm",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
num_kv_heads=32,
hidden_size=4096,
vocab_size=103168,
hidden_act='silu',
n_positions=2048,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
bias=True,
)),
"internlm_chat_20b":
ModelConfig(name="internlm_chat_20b",
family="internlm",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=60,
num_heads=40,
num_kv_heads=40,
hidden_size=5120,
vocab_size=103168,
hidden_act='silu',
n_positions=4096,
inter_size=13824,
max_batch_size=64,
max_input_len=512,
max_output_len=200,
builder_opt=None,
bias=False,
)),
}

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -89,7 +89,14 @@ def parse_arguments():
type=float,
default="0",
help=('Specify Top-P value of decoding.'))
parser.add_argument(
'--profiling_verbosity',
type=str,
default='layer_names_only',
choices=['layer_names_only', 'detailed', 'none'],
help=
'The profiling verbosity for the generated TRT engine. Set to detailed can inspect tactic choices and kernel parameters.'
)
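
For context, the sketch below shows how such a flag is typically mapped onto TensorRT's builder configuration. The `nvinfer1::ProfilingVerbosity` enum and setter are public TensorRT API, but the exact plumbing inside TensorRT-LLM's builder is an assumption here, not shown in this diff.

```cpp
#include <NvInfer.h>
#include <string>

// Assumed mapping from the CLI string to TensorRT's profiling verbosity.
void applyProfilingVerbosity(nvinfer1::IBuilderConfig& config, std::string const& level)
{
    using nvinfer1::ProfilingVerbosity;
    if (level == "detailed")
        config.setProfilingVerbosity(ProfilingVerbosity::kDETAILED); // exposes tactic and kernel details
    else if (level == "none")
        config.setProfilingVerbosity(ProfilingVerbosity::kNONE);
    else
        config.setProfilingVerbosity(ProfilingVerbosity::kLAYER_NAMES_ONLY); // the default above
}
```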
parser.add_argument(
'--log_level',
type=str,
@ -180,6 +187,10 @@ def parse_arguments():
help=
'Quick sanity check with num_layer=1; will be silently ignored if --engine_dir is specified.'
)
parser.add_argument('--strongly_typed',
default=False,
action='store_true',
help='This option will reduce the building time.')
parser.add_argument('--csv',
default=False,

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -122,6 +122,10 @@ def parse_arguments():
default=False,
action='store_true',
help="Build engines serially")
parser.add_argument('--strongly_typed',
default=False,
action='store_true',
help='This option will reduce the building time.')
parser.add_argument(
'--rank',
@ -143,23 +147,19 @@ def parse_arguments():
def get_quant_mode(quantization):
quant_mode = QuantMode(0)
strongly_typed = False
use_smooth_quant = False
per_token = False
per_channel = False
weight_only_precision = 'int8'
if quantization == "fp8":
strongly_typed = True
quant_mode = quant_mode.set_fp8_qdq()
quant_mode = quant_mode.set_fp8_kv_cache()
elif quantization == "fp8_gemm":
strongly_typed = True
quant_mode = quant_mode.set_fp8_qdq()
elif quantization == "fp8_kv_cache":
strongly_typed = True
quant_mode = quant_mode.set_fp8_kv_cache()
elif quantization == "int8_sq_per_tensor":
@ -205,7 +205,7 @@ def get_quant_mode(quantization):
else:
raise Exception(f'Unexpected quantization: {quantization}')
return quant_mode, strongly_typed, use_smooth_quant, weight_only_precision
return quant_mode, use_smooth_quant, weight_only_precision
def build_gpt(args):
@ -223,6 +223,9 @@ def build_gpt(args):
if not args.serial_build:
torch.cuda.set_device(runtime_rank)
strongly_typed = args.strongly_typed
if args.quantization is not None and "fp8" in args.quantization:
strongly_typed = True
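
The lines above make FP8 quantization force a strongly typed engine. As a hedged C++ sketch of what that means at the TensorRT level (the flag name comes from the public TensorRT API, not from this diff): a strongly typed network fixes tensor dtypes up front, which also shortens the builder's tactic search.

```cpp
#include <NvInfer.h>
#include <cstdint>

// Assumption sketch: create the network in strongly typed mode when requested.
nvinfer1::INetworkDefinition* createNetwork(nvinfer1::IBuilder& builder, bool stronglyTyped)
{
    std::uint32_t flags = 0;
    if (stronglyTyped)
    {
        flags |= 1U << static_cast<std::uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
    }
    return builder.createNetworkV2(flags);
}
```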
num_kv_heads = build_config['num_heads'] \
if build_config['num_kv_heads'] is None else build_config['num_kv_heads']
apply_query_key_layer_scaling = False
@ -234,7 +237,7 @@ def build_gpt(args):
if args.max_output_len is None else args.max_output_len
max_beam_width = build_config['max_beam_width'] \
if args.max_beam_width is None else args.max_beam_width
quant_mode, strongly_typed, use_smooth_quant, weight_only_precision = get_quant_mode(
quant_mode, use_smooth_quant, weight_only_precision = get_quant_mode(
args.quantization)
use_weight_only = quant_mode.is_weight_only()
@ -243,6 +246,7 @@ def build_gpt(args):
name=args.model,
precision=args.dtype,
timing_cache=None,
profiling_verbosity=args.profiling_verbosity,
tensor_parallel=world_size, # TP only
parallel_build=True,
num_layers=build_config['num_layers'],
@ -468,21 +472,117 @@ def build_gpt(args):
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
elif family == "falcon":
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(
config = {
'architecture':
'FalconForCausalLM',
'dtype':
args.dtype,
'num_hidden_layers':
build_config['num_layers'],
'num_attention_heads':
build_config['num_heads'],
'num_key_value_heads':
build_config['num_heads'] if build_config['num_kv_heads'] is None
else build_config['num_kv_heads'],
'hidden_size':
build_config['hidden_size'],
'vocab_size':
build_config['vocab_size'],
'position_embedding_type':
'alibi_with_scale'
if build_config['use_alibi'] else 'rope_gpt_neox',
'max_position_embeddings':
build_config['n_positions'],
'hidden_act':
build_config['hidden_act'],
'quantization': {
'use_smooth_quant':
quant_mode.has_act_and_weight_quant(),
'per_channel':
quant_mode.has_per_channel_scaling(),
'per_token':
quant_mode.has_per_token_dynamic_scaling(),
'per_group':
quant_mode.has_per_group_scaling(),
'group_size':
128,
'int8_kv_cache':
quant_mode.has_int8_kv_cache(),
'enable_fp8':
quant_mode.has_fp8_qdq(),
'fp8_kv_cache':
quant_mode.has_fp8_kv_cache(),
'use_weight_only':
quant_mode.is_weight_only(),
'weight_only_precision':
'int8' if quant_mode.is_int8_weight_only() else 'int4',
},
'mapping': {
'world_size': world_size,
'tp_size': world_size
},
'bias':
build_config['bias'],
'parallel_attention':
build_config['parallel_attention'],
'new_decoder_architecture':
build_config['new_decoder_architecture'],
}
if quant_mode.is_weight_only() and quant_mode.has_per_group_scaling():
config['quantization'].update({
'zero': False,
'pre_quant_scale': True,
'exclude_modules': [],
})
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
elif family == "baichuan_7b":
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=None,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
dtype=kv_dtype,
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),
quant_mode=quant_mode)
elif family == "baichuan_13b":
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=None,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
position_embedding_type=PositionEmbeddingType.alibi,
dtype=kv_dtype,
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),
quant_mode=quant_mode)
elif family == "internlm":
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=num_kv_heads,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
dtype=kv_dtype,
bias=build_config['bias'],
quant_mode=quant_mode,
use_alibi=build_config['use_alibi'],
new_decoder_architecture=build_config['new_decoder_architecture'],
parallel_attention=build_config['parallel_attention'],
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size))
tp_size=world_size), # TP only
quant_mode=quant_mode,
embedding_sharding_dim=1,
use_fused_mlp=False,
attn_bias=build_config['bias'])
else:
raise Exception(f'Unexpected model: {args.model}')
@ -501,7 +601,7 @@ def build_gpt(args):
"zero": True,
"pre_quant_scale": False,
}
if family not in ['opt', 'bloom']:
if family not in ['opt', 'bloom', 'falcon']:
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
@ -550,7 +650,7 @@ def build_gpt(args):
max_input_len,
max_output_len, True,
max_beam_width)
if family in ['opt', 'bloom']:
if family in ['opt', 'bloom', 'falcon']:
tensorrt_llm_model(**inputs)
else:
tensorrt_llm_model(*inputs)
@ -604,6 +704,7 @@ def build_bert(args):
name=args.model,
precision=args.dtype,
timing_cache=None,
profiling_verbosity=args.profiling_verbosity,
tensor_parallel=world_size, # TP only
parallel_build=True,
num_layers=build_config['num_layers'],

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -56,7 +56,7 @@ class GPTBenchmark(BaseBenchmark):
if args.max_output_len is not None:
self.max_output_len = args.max_output_len
self.quant_mode, _, _, _ = get_quant_mode(args.quantization)
self.quant_mode, _, _ = get_quant_mode(args.quantization)
self.enable_fp8 = self.quant_mode.has_fp8_qdq()
self.fp8_kv_cache = self.quant_mode.has_fp8_kv_cache()

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -170,6 +170,7 @@ public:
\
void set##funcName(TensorPtr const& tensor) \
{ \
TLLM_CHECK_WITH_INFO(tensor, "Cannot set nullptr when calling %s", __FUNCTION__); \
mInputTensors[tensorName] = tensor; \
}

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -31,9 +31,11 @@ public:
explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
std::optional<SizeType> maxAttentionWindow = std::nullopt,
std::optional<SizeType> sinkTokenLength = std::nullopt,
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false)
: maxTokens{maxTokens}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, freeGpuMemoryFraction{freeGpuMemoryFraction}
, enableBlockReuse(enableBlockReuse)
{
@ -41,10 +43,9 @@ public:
std::optional<SizeType> maxTokens;
std::optional<SizeType> maxAttentionWindow;
std::optional<SizeType> sinkTokenLength;
std::optional<float> freeGpuMemoryFraction;
bool enableBlockReuse;
static constexpr auto kDefaultGpuMemFraction = 0.85f;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
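
A minimal usage sketch for the extended constructor follows; the values are illustrative and the include path is assumed from the repository layout.

```cpp
#include "tensorrt_llm/batch_manager/kvCacheConfig.h" // path assumed from the repo layout
#include <optional>

using tensorrt_llm::batch_manager::kv_cache_manager::KvCacheConfig;

// Keep the first 4 "sink" tokens resident while the rest of the cache is
// bounded by a 2048-token attention window and 90% of free GPU memory.
KvCacheConfig makeKvCacheConfig()
{
    return KvCacheConfig(
        /*maxTokens=*/std::nullopt,
        /*maxAttentionWindow=*/2048,
        /*sinkTokenLength=*/4,
        /*freeGpuMemoryFraction=*/0.9f,
        /*enableBlockReuse=*/false);
}
```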

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -129,6 +129,8 @@ public:
[[nodiscard]] bool isFull() const;
[[nodiscard]] bool isShared() const;
private:
// Linear index of block in pool
SizeType mBlockIdx;
@ -199,6 +201,11 @@ public:
mCacheBlockIds.at(beamIdx).push_back(blockIdx);
}
void changeCacheBlock(SizeType beamIdx, SizeType pagedBlockIdx, SizeType blockIdx)
{
mCacheBlockIds.at(beamIdx).at(pagedBlockIdx) = blockIdx;
}
void clearCacheBlocks()
{
for (auto& beamBlockIds : mCacheBlockIds)
@ -259,12 +266,14 @@ public:
void addSequence(GenerationRequest& sequence, SizeType inputLength, std::shared_ptr<LlmRequest> const& llmRequest);
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType inputLength, bool enableCyclicKvCache);
void addSequence(GenerationRequest& sequence, SizeType numBlocks, SizeType unsharedBlockIdx);
//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
void replaceSharedBlock(GenerationRequest& sequence, SizeType blockIdx);
//! \brief Release blocks of the sequence. Store blocks for reuse if llmReqeust is provided.
void releaseBlocks(GenerationRequest& sequence, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
@ -354,8 +363,8 @@ public:
KVCacheManager(SizeType numLayers, SizeType numHeads, SizeType numKvHeads, SizeType hiddenSize,
SizeType tokensPerBlock, SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth,
SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, nvinfer1::DataType dtype, CudaStreamPtr stream,
bool enableBlockReuse = false);
SizeType maxBlocksPerSeq, SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock,
nvinfer1::DataType dtype, CudaStreamPtr stream, bool enableBlockReuse = false);
void startScheduling();
@ -466,6 +475,7 @@ private:
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx);
private:
// Number of elements per one blocks
@ -479,6 +489,14 @@ private:
// Maximum kv cache length per sequence
// Enable cyclic kv cache when it exceeds
SizeType mMaxAttentionWindow;
// Sink token length in the kv cache per sequence
SizeType mSinkTokenLength;
// Bubble token length
SizeType mBubbleLength;
// Maximum token length (including bubble)
SizeType mMaxTokenNum;
// Number of tokens in the sink blocks
SizeType mSinkBlockTokenLength;
// Pools
std::vector<runtime::ITensor::SharedPtr> mPools;
// Block manager
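
One plausible reading of the new bookkeeping members, stated as an assumption rather than something this diff spells out: sink tokens occupy whole cache blocks, so the last sink block can carry a "bubble" of unused slots.

```cpp
#include <cstdint>

// Assumption sketch only: round the sink token count up to a block boundary;
// the difference would be the bubble of unused slots in the sink block(s).
using SizeType = std::int32_t;

SizeType roundUpToBlock(SizeType sinkTokenLength, SizeType tokensPerBlock)
{
    return (sinkTokenLength + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
}

// e.g. sinkTokenLength = 4, tokensPerBlock = 64:
//   sink block token length = 64, bubble length = 64 - 4 = 60
```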

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -34,12 +34,14 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true)
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true,
bool logIterationData = false)
: kvCacheConfig{kvCacheConfig}
, maxNumSequences{maxNumSequences}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
, normalizeLogProbs{normalizeLogProbs}
, logIterationData{logIterationData}
{
}
@ -48,6 +50,7 @@ public:
bool enableTrtOverlap;
std::optional<std::vector<SizeType>> deviceIds;
bool normalizeLogProbs;
bool logIterationData;
};
} // namespace tensorrt_llm::batch_manager
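
A short usage sketch tying the new field to the `--log_iteration_data` benchmark flag above; the values are illustrative and the include path is assumed.

```cpp
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" // path assumed from the repo layout
#include <optional>

using tensorrt_llm::batch_manager::TrtGptModelOptionalParams;

// Ask GptManager to report batch statistics on every decoder iteration.
TrtGptModelOptionalParams makeOptionalParams()
{
    TrtGptModelOptionalParams params;
    params.kvCacheConfig.freeGpuMemoryFraction = 0.9f;
    params.enableTrtOverlap = false;
    params.logIterationData = true; // the stats callback fires each iteration
    return params;
}
```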

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,11 +29,12 @@ class DecodingInput
public:
using TensorPtr = std::shared_ptr<ITensor const>;
DecodingInput(
SizeType maxLength, SizeType maxAttentionWindow, SizeType batchSize, TensorPtr logits, TensorPtr endIds)
DecodingInput(SizeType maxLength, SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType batchSize,
TensorPtr logits, TensorPtr endIds)
: step{maxLength}
, maxLength{maxLength}
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, batchSize{batchSize}
, logits{std::move(logits)}
, endIds{std::move(endIds)}
@ -46,6 +47,7 @@ public:
SizeType step;
SizeType maxLength;
SizeType maxAttentionWindow;
SizeType sinkTokenLength;
SizeType batchSize;
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
TensorPtr endIds; // [batchSize * beamWidth], on gpu

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -45,8 +45,8 @@ public:
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(
@ -201,6 +201,7 @@ private:
// decoding accept by logits kernel, on gpu
SizeType mMaxSequenceLength{};
SizeType mMaxAttentionWindow{};
SizeType mSinkTokenLength{};
SizeType mActualBatchSize{};
SizeType mMaxTokensPerStep{};
};
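
Call sites of the decoder now pass the sink token length between the attention window and the maximum sequence length; a hedged call sketch with illustrative values (the namespace and header path are assumptions based on the repository layout):

```cpp
#include "tensorrt_llm/runtime/gptDecoderBatch.h" // path assumed from the repo layout
#include <NvInfer.h>

void setupDecoder(tensorrt_llm::runtime::GptDecoderBatch& decoder)
{
    decoder.setup(/*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*maxAttentionWindow=*/4096,
        /*sinkTokenLength=*/0, /*maxSequenceLength=*/4096, /*maxTokensPerStep=*/1,
        nvinfer1::DataType::kHALF);
}
```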

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -145,10 +145,10 @@ private:
void createContexts();
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType maxSequenceLength,
nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, KvCacheConfig const& config);
SizeType sinkTokenLength, SizeType maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
void executeContextStep(std::vector<GenerationInput> const& microBatchesInputs,
@ -257,6 +257,7 @@ private:
SizeType mDecoderMaxSequenceLength{};
SizeType mDecoderMaxAttentionWindow{};
SizeType mDecoderSinkTokenLength{};
LoggerPtr mLogger;
std::shared_ptr<TllmRuntime> mRuntime;

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -75,7 +75,7 @@ public:
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
= 0;
//! @brief Initialize the decoder with new batch of inputs.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,6 +1,6 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5007e359b3c93562b81ea6dbf414a6cb98a88de2f82aebb740a044a2deb3946
size 1846872
oid sha256:327edb4d1e50392467f194cb8ccacad39d2d872d1f89aef79cafa203171a4734
size 1858074

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:256cd1e9da6a7e77ed2981abbeca7fb18660b9469643791333d9a909c01bc601
size 1860514
oid sha256:e55ee683c569bde1fd18442152b201cf4ebb41bbe21c6c1c6abfc5bac6256e5f
size 1873024

View File

@ -1,3 +1,3 @@
097b6a4d83e0c954f5b6510c0e871e87 libtensorrt_llm_batch_manager_static.a
9c73d43e5b1a94dbd16c94237905719e libtensorrt_llm_batch_manager_static.pre_cxx11.a
0067d2b225abceb20316ddb3a5cddf426ef25160 commit
a4ea2e88effa61e397d984943fd47802 libtensorrt_llm_batch_manager_static.a
94a233ead0aca9a51776a0bab70c59b4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
9251c03ebe80b196becd8ac3abbfa02c2a3273ad commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c95be2b543ce79a591b12569fc5d245b76c72b9e0485b17bfb5f16bc46fa7029
size 1775504
oid sha256:0a8dc8411449452686afc7b4005cdb77905914edc9d5257d5d283b0dfc4eb9aa
size 1790812

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fe2ddf4d130704e82f0ab299b203408684432e9c111bf481db2b34a9a1023b83
size 1763222
oid sha256:fda2fdc9c2b3672e94f15927b6bbeb5321c436761c8d1b1de96f23f6807a351a
size 1776536

View File

@ -1,2 +1,2 @@
cd974e5d12c72241dec0c6b439f2f7a0 libtensorrt_llm_batch_manager_static.a
4e88c3c8609582273a1b026fadac4abd libtensorrt_llm_batch_manager_static.pre_cxx11.a
1c62064ee5f68d76194bad877504b0d3 libtensorrt_llm_batch_manager_static.a
4d7b4ff0c4c14865a18535cef6117ee9 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -535,9 +535,13 @@ public:
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream)
{
KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
pagedKVCacheForContextMHA = KVBlockArrayForContextFMHA(pagedKVCache.mMaxSeqs, pagedKVCache.mMaxBlocksPerSeq,
pagedKVCache.mTokensPerBlock, mPagedKVParams.h_kv * mPagedKVParams.d * sizeof(half));
pagedKVCacheForContextMHA.data = pagedKVCache.data;
mPagedKVParams.q_ptr = qPtr;
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
mPagedKVParams.paged_kv_cache = pagedKVCache;
mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA;
mPagedKVParams.o_ptr = outputPtr;
mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);

View File

@ -179,7 +179,7 @@ struct Fused_multihead_attention_paged_kv_params_v2
// The Q matrices.
const void* q_ptr;
// Paged KV Cache buffer.
KVBlockArray paged_kv_cache;
KVBlockArrayForContextFMHA paged_kv_cache;
// The O matrix (output).
void* o_ptr;
// The packed mask for random mask.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -29,7 +29,8 @@ namespace mmha
// Forward declaration of the kernel launcher to avoid including decoderMaskedMultiheadAttentionLaunch.h
template <typename T, typename KVCacheBuffer, typename T_PARAMS, int Dh>
void mmha_launch_kernel(const T_PARAMS& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream);
void mmha_launch_kernel(const T_PARAMS& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
} // namespace mmha
@ -37,44 +38,67 @@ namespace
{
#define MMHA_LAUNCH_KERNEL(Dh) \
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, Dh>(params, kv_cache_buffer, stream); \
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, Dh>( \
params, kv_cache_buffer, shift_k_cache, stream); \
break;
template <typename T, typename KVCacheBuffer, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(
const KERNEL_PARAMS_TYPE& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head)
{
case 32: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 32>(params, kv_cache_buffer, stream); break;
case 64: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 64>(params, kv_cache_buffer, stream); break;
case 32:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 32>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 64:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 64>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 128:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 128>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 256:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 256>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
#ifndef FAST_BUILD // skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
case 48: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(params, kv_cache_buffer, stream); break;
case 80: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 80>(params, kv_cache_buffer, stream); break;
case 96: mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 96>(params, kv_cache_buffer, stream); break;
case 48:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 48>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 80:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 80>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 96:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 96>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 112:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 112>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 112>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 144:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 144>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 144>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 160:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 160>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 160>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 192:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 192>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 192>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
case 224:
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 224>(params, kv_cache_buffer, stream);
mmha::mmha_launch_kernel<T, KVCacheBuffer, KERNEL_PARAMS_TYPE, 224>(
params, kv_cache_buffer, shift_k_cache, stream);
break;
#endif // FAST_BUILD
default: TLLM_THROW("unsupported head_size");
default: TLLM_CHECK_WITH_INFO(false, "unsupported head_size %d", params.hidden_size_per_head);
}
}
@ -86,16 +110,16 @@ void multihead_attention_(
#define INSTANTIATE_MMHA_NORMAL_AND_PAGED(T, CROSS_ATTENTION) \
void masked_multihead_attention(const Multihead_attention_params<T, CROSS_ATTENTION>& params, \
const KVBlockArray& kv_cache_buffer, const cudaStream_t& stream) \
const KVBlockArray& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) \
{ \
multihead_attention_<T, KVBlockArray, Multihead_attention_params<T, CROSS_ATTENTION>>( \
params, kv_cache_buffer, stream); \
params, kv_cache_buffer, shift_k_cache, stream); \
} \
void masked_multihead_attention(const Multihead_attention_params<T, CROSS_ATTENTION>& params, \
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream) \
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) \
{ \
multihead_attention_<T, KVLinearBuffer, Multihead_attention_params<T, CROSS_ATTENTION>>( \
params, kv_cache_buffer, stream); \
params, kv_cache_buffer, shift_k_cache, stream); \
}
INSTANTIATE_MMHA_NORMAL_AND_PAGED(float, true)
INSTANTIATE_MMHA_NORMAL_AND_PAGED(float, false)

View File

@ -108,6 +108,8 @@ struct Multihead_attention_params_base
int max_attention_window_size = 0;
// Cyclic kv cache capacity (used to get the cyclic kv cache position for new tokens)
int cyclic_attention_window_size = 0;
// Length of the sink token in KV cache
int sink_token_length = 0;
// The number of heads (H).
int num_heads = 0;
// Controls MHA/MQA/GQA
@ -122,6 +124,8 @@ struct Multihead_attention_params_base
RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
float rotary_embedding_scale = 0.0f;
int rotary_embedding_max_positions = 0;
// Position shift for streamingllm
bool position_shift_enabled = false;
// The current timestep. TODO: check whether we only need this param in cross attention.
int timestep = 0;
// The current timestep of each sentences (support different timestep for different sentences)
@ -222,13 +226,13 @@ using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
#define DECLARE_MMHA_NORMAL_AND_PAGED(T) \
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
const KVBlockArray& block_array, const cudaStream_t& stream); \
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, \
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream); \
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
const KVBlockArray& block_array, const cudaStream_t& stream); \
const KVBlockArray& block_array, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
void masked_multihead_attention(const Cross_multihead_attention_params<T>& params, \
const KVLinearBuffer& kv_cache_buffer, const cudaStream_t& stream);
const KVLinearBuffer& kv_cache_buffer, const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
DECLARE_MMHA_NORMAL_AND_PAGED(float);
DECLARE_MMHA_NORMAL_AND_PAGED(uint16_t);
#ifdef ENABLE_BF16

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@ -149,15 +149,15 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
if (dynamic_smem_sz >= 46 * 1024) \
{ \
cudaError_t res = cudaFuncSetAttribute( \
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>, \
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
TLLM_CHECK_WITH_INFO( \
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
} \
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&available_blocks, \
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>, \
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>, \
DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz));
#define MMHA_KERNEL(DYNAMIC_THDS_PER_BLOCK, ENABLE_MULTI_BLOCK) \
@ -166,16 +166,17 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */ \
if (dynamic_smem_sz >= 46 * 1024) \
{ \
cudaError_t res = cudaFuncSetAttribute( \
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
cudaError_t res \
= cudaFuncSetAttribute(mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, \
KCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, \
HAS_BEAMS, ENABLE_MULTI_BLOCK, POS_SHIFT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_sz); \
TLLM_CHECK_WITH_INFO( \
res == cudaSuccess, "Sequence Length is too long for the MMHA kernel (not enough shared memory)."); \
} \
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, DYNAMIC_THDS_PER_BLOCK, \
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK> \
<<<grid, DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz, stream>>>(params, kv_cache_buffer);
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, \
DYNAMIC_THDS_PER_BLOCK, KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, ENABLE_MULTI_BLOCK, POS_SHIFT> \
<<<grid, DYNAMIC_THDS_PER_BLOCK, dynamic_smem_sz, stream>>>(params, kv_cache_buffer, k_cache_buffer);
// If resources are not enough to launch 512 threads per block, we will fall back to 256.
#define MMHA_512_BLOCKSIZE_CHECK() \
@ -204,10 +205,10 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params<
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, typename T_cache, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK,
bool HAS_BEAMS, bool DO_MULTI_BLOCK>
void mmha_launch_kernel_ex(
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream, int tlength)
template <typename T, typename T_cache, typename TKcache, typename KVCacheBuffer, typename KCacheBuffer,
typename KernelParamsType, int Dh, int THDS_PER_BLOCK, bool HAS_BEAMS, bool DO_MULTI_BLOCK, bool POS_SHIFT>
void mmha_launch_kernel_ex(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
const KCacheBuffer& k_cache_buffer, const cudaStream_t& stream, int tlength)
{
dim3 grid{static_cast<unsigned>(params.num_heads), static_cast<unsigned>(params.batch_size), 1};
@ -225,8 +226,8 @@ void mmha_launch_kernel_ex(
// Set 0 dynamic shared memory size as we need the number of available blocks limited by registers.
// Dynamic shared memory is fixed for different block size.
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
mmha::masked_multihead_attention_kernel<T, T_cache, KVCacheBuffer, Dh, THDS_PER_BLOCK,
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK>,
mmha::masked_multihead_attention_kernel<T, T_cache, TKcache, KVCacheBuffer, KCacheBuffer, Dh, THDS_PER_BLOCK,
KernelParamsType::DO_CROSS_ATTENTION, HAS_BEAMS, DO_MULTI_BLOCK, POS_SHIFT>,
THDS_PER_BLOCK, 0));
int block_size_factor
@ -289,61 +290,80 @@ void mmha_launch_kernel_ex(
}
}
template <typename T, typename T_cache, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK,
bool HAS_BEAMS, bool DO_MULTI_BLOCK>
void mmha_launch_kernel_dispatch_pos_shift(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength)
{
if (params.position_shift_enabled && !KernelParamsType::DO_CROSS_ATTENTION)
{
mmha_launch_kernel_ex<T, T_cache, T, KVCacheBuffer, KVLinearBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
HAS_BEAMS, DO_MULTI_BLOCK, true>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
else
{
mmha_launch_kernel_ex<T, T_cache, T_cache, KVCacheBuffer, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
HAS_BEAMS, DO_MULTI_BLOCK, false>(params, kv_cache_buffer, kv_cache_buffer, stream, tlength);
}
}
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh, int THDS_PER_BLOCK, bool HAS_BEAMS,
bool DO_MULTI_BLOCK>
void mmha_launch_kernel_dispatch_8bits_kv_cache(
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream, int tlength)
void mmha_launch_kernel_dispatch_8bits_kv_cache(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength)
{
if (params.int8_kv_cache)
{
mmha_launch_kernel_ex<T, int8_t, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
DO_MULTI_BLOCK>(params, kv_cache_buffer, stream, tlength);
mmha_launch_kernel_dispatch_pos_shift<T, int8_t, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
#ifdef ENABLE_FP8
else if (params.fp8_kv_cache)
{
mmha_launch_kernel_ex<T, __nv_fp8_e4m3, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
DO_MULTI_BLOCK>(params, kv_cache_buffer, stream, tlength);
mmha_launch_kernel_dispatch_pos_shift<T, __nv_fp8_e4m3, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK,
HAS_BEAMS, DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
#endif // ENABLE_FP8
else
{
mmha_launch_kernel_ex<T, T, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS, DO_MULTI_BLOCK>(
params, kv_cache_buffer, stream, tlength);
mmha_launch_kernel_dispatch_pos_shift<T, T, KVCacheBuffer, KernelParamsType, Dh, THDS_PER_BLOCK, HAS_BEAMS,
DO_MULTI_BLOCK>(params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
}
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh, bool HAS_BEAMS>
void mmha_launch_kernel_dispatch(
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
void mmha_launch_kernel_dispatch(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
{
int const tlength = params.timestep;
if (params.multi_block_mode)
{
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, true>(
params, kv_cache_buffer, stream, tlength);
params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
else
{
mmha_launch_kernel_dispatch_8bits_kv_cache<T, KVCacheBuffer, KernelParamsType, Dh, 256, HAS_BEAMS, false>(
params, kv_cache_buffer, stream, tlength);
params, kv_cache_buffer, shift_k_cache, stream, tlength);
}
}
template <typename T, typename KVCacheBuffer, typename KernelParamsType, int Dh>
void mmha_launch_kernel(
const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream)
void mmha_launch_kernel(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer,
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream)
{
assert((params.rotary_embedding_dim != 0)
== (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ));
if (params.beam_width == 1)
{
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false>(params, kv_cache_buffer, stream);
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false>(
params, kv_cache_buffer, shift_k_cache, stream);
}
else
{
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, true>(params, kv_cache_buffer, stream);
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, true>(
params, kv_cache_buffer, shift_k_cache, stream);
}
}
@ -352,16 +372,16 @@ void mmha_launch_kernel(
#define INSTANTIATE_MMHA_LAUNCHERS(T, Dh) \
template void mmha_launch_kernel<T, KVLinearBuffer, Masked_multihead_attention_params<T>, Dh>( \
const Masked_multihead_attention_params<T>& params, const KVLinearBuffer& kv_cache_buffer, \
const cudaStream_t& stream); \
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
template void mmha_launch_kernel<T, KVBlockArray, Masked_multihead_attention_params<T>, Dh>( \
const Masked_multihead_attention_params<T>& params, const KVBlockArray& kv_cache_buffer, \
const cudaStream_t& stream); \
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
template void mmha_launch_kernel<T, KVLinearBuffer, Cross_multihead_attention_params<T>, Dh>( \
const Cross_multihead_attention_params<T>& params, const KVLinearBuffer& kv_cache_buffer, \
const cudaStream_t& stream); \
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); \
template void mmha_launch_kernel<T, KVBlockArray, Cross_multihead_attention_params<T>, Dh>( \
const Cross_multihead_attention_params<T>& params, const KVBlockArray& kv_cache_buffer, \
const cudaStream_t& stream);
const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1227,8 +1227,12 @@ template <
typename T,
// The type of the cache.
typename Tcache,
// The type of the shift key cache.
typename TKcache,
// Type of struct containing KV cache
typename KVCacheBuffer,
// Type of struct containing K cache to read past keys
typename KCacheBuffer,
// The hidden dimension per head.
unsigned Dh,
// The number of threads in a threadblock.
@ -1239,6 +1243,8 @@ template <
bool HAS_BEAMS,
// Whether to enable multi-block mode for long-sequence-length.
bool DO_MULTI_BLOCK = false,
// Whether to enable position shift for streamingllm
bool POS_SHIFT = false,
// The number of threads per key.
unsigned THREADS_PER_KEY = threads_per_key<T, dh_max(Dh)>(),
// The number of threads per value.
@ -1249,13 +1255,15 @@ template <
// The unroll factor for loading from V cache.
unsigned V_LOOP_UNROLL = 8>
__global__ void masked_multihead_attention_kernel(
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer)
Multihead_attention_params<T, DO_CROSS_ATTENTION> params, KVCacheBuffer kvCacheBuffer, KCacheBuffer pastKCache)
{
using Tk = typename kernel_type_t<T>::Type;
// Use 8bit cache.
static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
static constexpr bool ENABLE_8BITS_K_CACHE = sizeof(TKcache) == 1;
static constexpr bool ENABLE_8BITS_KV_CACHE = sizeof(Tcache) == 1;
// FP8 KV Cache.
static constexpr bool FP8_K_CACHE = std::is_same<TKcache, __nv_fp8_e4m3>::value;
static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
// INT8 KV Cache.
static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;
@ -1276,6 +1284,8 @@ __global__ void masked_multihead_attention_kernel(
// Note that max_attention_window_size is the maximum of cyclic_attention_window_size among all layers.
// By default, you can assume that they are the same.
const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
// The number of sink tokens in kv cache to support streamingllm
const auto sink_token_len = static_cast<unsigned>(params.sink_token_length);
// The current timestep (including paddings).
// It is only used to calculate the smem stride.
const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);
@ -1335,7 +1345,7 @@ __global__ void masked_multihead_attention_kernel(
// The type of queries and keys for the math in the Q*K^T product.
using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
// Only used when key cache is quantized to 8 bits.
using K_vec_m = typename packed_type<Tcache, num_elems<K_vec_k>::value>::type;
using K_vec_m = typename packed_type<TKcache, num_elems<K_vec_k>::value>::type;
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
using K_vec_accum = typename Qk_vec_accum_fp32_<K_vec_k>::Type;
#else
@ -1424,7 +1434,12 @@ __global__ void masked_multihead_attention_kernel(
: (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast<int>(timestep));
// We will use the cyclic kv cache when the sequence length exceeds the limit.
// The length position for storing new key and value.
const int cyclic_tlength = tlength % cyclic_kv_cache_len;
const int cyclic_tlength = kvCacheBuffer.getKVTokenIdx(tlength);
// When cyclic kv cache and one-more-block mode are enabled, we need to shift the index to the actual index in the
// sequence. Otherwise, if the token is not a sink token, we need to add the bubble length to the index.
const bool enable_use_seq_idx_kv = kvCacheBuffer.mEnableOneMoreBlock && tlength > cyclic_kv_cache_len;
const int shift_for_cyclic_kv = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : kvCacheBuffer.mBubbleLen;
const int shift_for_cyclic_k = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : pastKCache.mBubbleLen;
// The actual kv cache length.
// tlength is the past length actually.
const int kv_loop_length = min(tlength, cyclic_kv_cache_len);
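For readers following the new index math above, here is a minimal host-side sketch (illustration only, not the kernel code) of how a token index in the sequence is shifted before it addresses the K/V cache. The variable names mirror the kernel; the extra remapping performed by getKVTokenIdx() in one-more-block mode is deliberately left out because it depends on KVBlockArray internals.

// Host-side illustration of the index bookkeeping above (not the kernel code).
// Sink tokens keep their original slots; other tokens are offset either by the
// block "bubble" length or, once the cache has wrapped in one-more-block mode,
// by (tlength - cyclic_kv_cache_len) before the cache-internal remapping.
int shiftTokenIdxForCyclicKVCache(int token_idx, int tlength, int cyclic_kv_cache_len,
                                  int sink_token_len, int bubble_len, bool enable_one_more_block)
{
    if (token_idx < sink_token_len)
    {
        return token_idx; // sink tokens are never shifted
    }
    const bool use_seq_idx = enable_one_more_block && tlength > cyclic_kv_cache_len;
    // Same shift as shift_for_cyclic_kv / shift_for_cyclic_k in the kernel; the
    // subsequent getKVTokenIdx() remapping is omitted here.
    return token_idx + (use_seq_idx ? tlength - cyclic_kv_cache_len : bubble_len);
}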
@ -1433,6 +1448,8 @@ __global__ void masked_multihead_attention_kernel(
// as context kv cache might be overwritten by the new kv cache
const int beam0_context_length
= HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
// The position of the current timestep; it is used to apply the position embedding.
const int current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
// The offset in the Q and K buffer also accounts for the batch.
const auto qk_vec_idx = tidx * QK_VEC_SIZE;
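current_pos_idx above is the one place where the streamingllm position shift changes the RoPE inputs: without the shift (or for cross attention) the absolute timestep is used, otherwise the token's index inside the kept window. A minimal sketch of that selection, assuming only the quantities already defined above:

// Minimal sketch of the position selection above (illustration only).
int ropePositionForCurrentToken(bool pos_shift, bool cross_attention,
                                int tlength, int kv_loop_length)
{
    return (!pos_shift || cross_attention) ? tlength        // absolute timestep
                                           : kv_loop_length; // index inside the kept window
}

Note that when POS_SHIFT is enabled the current key is also kept without the rotary embedding applied (k_wo_pos further below), and that un-rotated copy is what gets written to shared memory in that mode.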
@ -1443,25 +1460,29 @@ __global__ void masked_multihead_attention_kernel(
// Quant/Dequant scales for 8bits kv cache.
using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
T_scale kv_scale_orig_quant, kv_scale_quant_orig;
const float kv_scale_quant_orig_f = (ENABLE_8BITS_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
convert_from_float(&kv_scale_quant_orig, kv_scale_quant_orig_f);
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
T_scale kv_scale_orig_quant, k_scale_quant_orig;
const float k_scale_quant_orig_f = (ENABLE_8BITS_K_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
const float kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f);
convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f);
convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f));
// Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
// Trigger the loads from the Q and K buffers.
Qk_vec_k q, k, q_bias, k_bias;
// key without position embedding
Qk_vec_k k_wo_pos;
zero(q);
zero(k);
zero(q_bias);
zero(k_bias);
zero(k_wo_pos);
float rotary_embedding_base = params.rotary_embedding_base;
float rotary_embedding_scale = params.rotary_embedding_scale;
if (is_valid_qk_vec)
{
mmha::update_rotary_base_n_scale(rotary_embedding_base, rotary_embedding_scale,
params.rotary_embedding_scale_type, params.rotary_embedding_dim, params.rotary_embedding_max_positions,
tlength);
current_pos_idx);
// Query
// The stride between tokens. We may be able to always use params.stride.
uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
@ -1553,6 +1574,7 @@ __global__ void masked_multihead_attention_kernel(
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
&params.ia3_key_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, qk_vec_idx, Dh)])));
}
k_wo_pos = k;
// Note we have no paddings in KV cache now.
switch (params.position_embedding_type)
@ -1569,12 +1591,12 @@ __global__ void masked_multihead_attention_kernel(
if (HANDLE_KV)
{
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
params.rotary_embedding_scale, tlength);
params.rotary_embedding_scale, current_pos_idx);
}
else
{
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
params.rotary_embedding_scale, tlength);
params.rotary_embedding_scale, current_pos_idx);
}
break;
}
@ -1613,14 +1635,14 @@ __global__ void masked_multihead_attention_kernel(
mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale, tlength);
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
}
else
{
mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale, tlength);
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
}
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
}
@ -1648,7 +1670,7 @@ __global__ void masked_multihead_attention_kernel(
// Store the Q values to shared memory.
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
if constexpr (FP8_KV_CACHE)
if constexpr (FP8_K_CACHE)
{
// There are many more elements from K than elements from Q so we pre-scale Q instead
// of scaling all the elements from K. It helps reduce the number of ops.
@ -1656,7 +1678,7 @@ __global__ void masked_multihead_attention_kernel(
zero(scaled_q);
if (is_valid_qk_vec)
{
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(kv_scale_quant_orig, q);
scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(k_scale_quant_orig, q);
}
reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
}
@ -1672,7 +1694,14 @@ __global__ void masked_multihead_attention_kernel(
// Store the K values to shared memory.
// We store K values from shared memory to global memory
// when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache)
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
if (POS_SHIFT && !DO_CROSS_ATTENTION)
{
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k_wo_pos;
}
else
{
reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;
}
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
qk = dot<Qk_vec_accum, Qk_vec_k>(q, k);
@ -1766,9 +1795,6 @@ __global__ void masked_multihead_attention_kernel(
// The number of unrolled keys per iteration.
constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL;
// Base pointer for the row of pointers to k cache blocks
void** k_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::K_IDX, batch_beam_idx));
const auto timesteps_per_block = static_cast<unsigned>(params.timesteps_per_block);
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
@ -1840,13 +1866,24 @@ __global__ void masked_multihead_attention_kernel(
// Dh OOB values will be handled by zero_q.
// Seq OOB values will be masked out when storing back to smem.
auto const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
const int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
if (valid_time_now >= sink_token_len)
{
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
// Otherwise, we need to add the bubble length to the index
valid_time_now += shift_for_cyclic_k;
if (enable_use_seq_idx_kv)
{
// Convert the token index in sequence to token index in K cache.
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
}
}
const int seqIdx = batch_idx * beam_width;
// Base pointer to k cache block for beam's batch
Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]);
}
}
@ -1904,16 +1941,16 @@ __global__ void masked_multihead_attention_kernel(
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
float qk_ = 0.f;
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
if constexpr (FP8_KV_CACHE)
if constexpr (FP8_K_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
}
else
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
{
if constexpr (ENABLE_8BITS_CACHE)
if constexpr (ENABLE_8BITS_K_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
* params.inv_sqrt_dh;
}
else
@ -1975,13 +2012,24 @@ __global__ void masked_multihead_attention_kernel(
for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i)
{
const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
const int valid_time_now = min(time_now, kv_loop_length - 1);
int valid_time_now = min(time_now, kv_loop_length - 1);
int beam_offset = beam_indices[valid_time_now];
if (valid_time_now >= sink_token_len)
{
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
// Otherwise, we need to add the bubble length to the index
valid_time_now += shift_for_cyclic_k;
if (enable_use_seq_idx_kv)
{
// Convert the token index in sequence to token index in K cache.
valid_time_now = pastKCache.getKVTokenIdx(valid_time_now);
}
}
const int seqIdx = batch_idx * beam_width + beam_offset;
// Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));
TKcache* k_cache_batch = reinterpret_cast<TKcache*>(pastKCache.getKBlockPtr(seqIdx, valid_time_now));
int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
k_vec[k_vec_i] = (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]));
}
@ -2024,16 +2072,16 @@ __global__ void masked_multihead_attention_kernel(
// Note that dot will convert 8bit vec to the accumulation data type (float by default).
float qk_ = 0.f;
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
if constexpr (FP8_KV_CACHE)
if constexpr (FP8_K_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
}
else
#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
{
if constexpr (ENABLE_8BITS_CACHE)
if constexpr (ENABLE_8BITS_K_CACHE)
{
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f)
qk_ = Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale_quant_orig_f)
* params.inv_sqrt_dh;
}
else
@ -2126,7 +2174,7 @@ __global__ void masked_multihead_attention_kernel(
// The base pointer for the key in the cache buffer.
Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));
if constexpr (ENABLE_8BITS_CACHE)
if constexpr (ENABLE_8BITS_KV_CACHE)
{
store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k_vec, inBlockIdx, kv_scale_orig_quant);
}
@ -2224,12 +2272,6 @@ __global__ void masked_multihead_attention_kernel(
const auto vo = v_idx.x;
// The hidden dimensions computed by this particular thread.
const auto vi = v_idx.y;
// Base pointer for the row of pointers to v cache blocks
void** v_cache_base_row_ptr = reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, batch_beam_idx));
// Base pointer for the row of pointers to v cache blocks for beam's batch, before offsetting with indirection
// buffer
void** v_cache_batch_row_ptr
= reinterpret_cast<void**>(kvCacheBuffer.getRowPtr(KVIdxType::V_IDX, batch_idx * beam_width));
// The number of values processed per iteration of the loop.
constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE};
@ -2293,6 +2335,17 @@ __global__ void masked_multihead_attention_kernel(
// Fetch offset based on cache_indir when beam sampling
int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
time_idx = min(time_idx, kv_loop_length - 1);
if (time_idx >= sink_token_len)
{
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
// Otherwise, we need to add the bubble length to the index
time_idx += shift_for_cyclic_kv;
if (enable_use_seq_idx_kv)
{
// Convert the token index in sequence to token index in V cache.
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
}
}
int rowIdx = batch_idx * beam_width;
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
@ -2339,6 +2392,18 @@ __global__ void masked_multihead_attention_kernel(
}
int rowIdx = batch_idx * beam_width + beam_indices[time_idx];
if (time_idx >= sink_token_len)
{
// If one more block mode is enabled, we use the index in sequence as tokenIdx.
// Otherwise, we need to add the bubble length to the index
time_idx += shift_for_cyclic_kv;
if (enable_use_seq_idx_kv)
{
// Convert the token index in sequence to token index in V cache.
time_idx = kvCacheBuffer.getKVTokenIdx(time_idx);
}
}
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
// The base pointer for the value in the cache buffer.
Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));
@ -2362,10 +2427,9 @@ __global__ void masked_multihead_attention_kernel(
// One group of threads computes the product(s) for the current timestep.
if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx)))
{
const int tokenIdx = cyclic_tlength;
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, vi);
// The base pointer for the value in the cache buffer.
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getBlockPtr(v_cache_base_row_ptr, tokenIdx));
Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(batch_beam_idx, cyclic_tlength));
V_vec_k v;
if (DO_CROSS_ATTENTION)
@ -2414,7 +2478,7 @@ __global__ void masked_multihead_attention_kernel(
// For MQA/GQA mode, write only with the first Q head of each group per KV head.
if (hi == (hi_kv * qhead_per_kv))
{
if (ENABLE_8BITS_CACHE)
if (ENABLE_8BITS_KV_CACHE)
{
store_8bits_kv_cache_vec(v_cache_base, v, inBlockIdx, kv_scale_orig_quant);
}

View File

@ -287,8 +287,8 @@ public:
}
template <typename T, bool HAS_BEAM>
void run(const XQAParams& xqaParams, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,
int multiprocessor_count, int max_multi_block_slots) const
void run(const XQAParams& xqaParams, KVLinearBuffer& kv_linear_buffer, int2& rotary_kernel_launch_cache,
const cudaStream_t& stream, int multiprocessor_count, int max_multi_block_slots) const
{
unsigned int head_size = xqaParams.head_size;
int num_q_heads = xqaParams.num_q_heads;
@ -303,11 +303,12 @@ public:
invokeApplyBiasRopeUpdateKVCache<T, KVLinearBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
nullptr, kv_linear_buffer, static_cast<const T*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr,
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.batch_size * beam_width,
xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, (float*) nullptr, 0,
cache_type, xqaParams.kv_scale_orig_quant, false, stream, beam_width);
nullptr, xqaParams.batch_size, 1, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
xqaParams.batch_size * beam_width, xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size,
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions,
xqaParams.position_embedding_type, xqaParams.position_shift_enabled, (float*) nullptr, 0, cache_type,
xqaParams.kv_scale_orig_quant, false, beam_width, rotary_kernel_launch_cache, stream);
XQAKernelRuntimeHashKey hash_key{xqaParams.kv_cache_data_type, head_size, num_q_heads_over_kv, beam_width};
const auto findIter = mFunctions.find(hash_key);
@ -459,8 +460,8 @@ public:
return xqaKernel->supportConfig(xqaParams) && xqaKernel->mayHavePerfGain(xqaParams, mMultiProcessorCount);
}
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream,
int max_multi_block_slots);
void run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, int2& rotary_kernel_launch_cache,
const cudaStream_t& stream, int max_multi_block_slots);
private:
const XQAKernelList* xqaKernel;
@ -470,17 +471,17 @@ private:
};
void DecoderXQARunner::xqaImpl::run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer,
const cudaStream_t& stream, int max_multi_block_slots)
int2& rotary_kernel_launch_cache, const cudaStream_t& stream, int max_multi_block_slots)
{
if (xqa_params.beam_width > 1)
{
xqaKernel->template run<__half, true>(
xqa_params, kv_linear_buffer, stream, mMultiProcessorCount, max_multi_block_slots);
xqaKernel->template run<__half, true>(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream,
mMultiProcessorCount, max_multi_block_slots);
}
else
{
xqaKernel->template run<__half, false>(
xqa_params, kv_linear_buffer, stream, mMultiProcessorCount, max_multi_block_slots);
xqaKernel->template run<__half, false>(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream,
mMultiProcessorCount, max_multi_block_slots);
}
}
@ -545,7 +546,7 @@ void DecoderXQARunner::run(const XQAParams& xqa_params, KVLinearBuffer& kv_linea
{
int max_multi_block_slots = kMaxBeamWidth * XQALaunchParam<true>::GetMaxBatchSizePerWave(mMultiProcessorCount)
* kMaxNbCtaPerKVHeadFactor * mNumKVHeads;
return pimpl->run(xqa_params, kv_linear_buffer, stream, max_multi_block_slots);
return pimpl->run(xqa_params, kv_linear_buffer, mLaunchGridBlockCache, stream, max_multi_block_slots);
}
} // namespace kernels

View File

@ -63,6 +63,7 @@ struct XQAParams
int32_t beam_width = 0;
int32_t max_attention_window_size = 0;
int32_t cyclic_attention_window_size = 0;
int32_t sink_token_length = 0;
int timestep = 0;
const void* qkv_bias;
const int32_t* sequence_lengths; //
@ -82,6 +83,7 @@ struct XQAParams
float rotary_embedding_scale;
int rotary_embedding_max_positions;
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type;
bool position_shift_enabled = false;
bool remove_padding = false;
tensorrt_llm::kernels::AttentionMaskType mask_type;
bool paged_kv_cache;
@ -144,6 +146,8 @@ public:
SUPPORT_RETURN_FALSE("beam_width");
if (xqaParams.cyclic_attention_window_size != xqaParams.max_attention_window_size)
SUPPORT_RETURN_FALSE("cyclic_attention_window_size != max_attention_window_size");
if (xqaParams.position_shift_enabled || xqaParams.sink_token_length > 0)
SUPPORT_RETURN_FALSE("streaming-llm");
return shouldUseImpl(xqaParams);
}
@ -182,6 +186,10 @@ private:
static constexpr int kMaxBeamWidth = 4;
// Cache the grid_size and block_size that give the highest occupancy for
// invokeApplyBiasRopeUpdateKVCache.
int2 mLaunchGridBlockCache = make_int2(0, 0);
class xqaImpl;
std::unique_ptr<xqaImpl> pimpl;
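The comment above only states what mLaunchGridBlockCache is for. As a hypothetical sketch (the real logic lives inside invokeApplyBiasRopeUpdateKVCache and is not shown here), a cache like this could be filled once with an occupancy-derived launch configuration and reused on subsequent calls:

#include <cuda_runtime.h>

// Hypothetical illustration only: fill a grid/block cache on first use.
template <typename KernelFunc>
void fillLaunchCacheOnce(int2& cache, KernelFunc kernel, size_t dynamic_smem_bytes = 0)
{
    if (cache.x == 0 || cache.y == 0)
    {
        int min_grid_size = 0, block_size = 0;
        // Query the block size that maximizes occupancy once, then reuse it.
        cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, kernel, dynamic_smem_bytes, 0);
        cache = make_int2(min_grid_size, block_size);
    }
}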

View File

@ -28,6 +28,7 @@ using tensorrt_llm::common::bf16hmul2;
using tensorrt_llm::common::bf16hmul;
using tensorrt_llm::common::bf16hadd2;
using tensorrt_llm::common::float22bf162;
using tensorrt_llm::common::hsub2;
#endif
namespace tensorrt_llm
@ -48,7 +49,7 @@ struct __align__(16) Float4_
////////////////////////////////////////////////////////////////////////////////////////////////////
struct __align__(32) Float8_
struct __align__(16) Float8_
{
float2 x;
float2 y;
@ -336,6 +337,32 @@ struct packed_type<float, 8>
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t sub(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 sub(__nv_bfloat162 a, __nv_bfloat162 b)
{
return hsub2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 sub(float2 a, float2 b)
{
float2 c;
c.x = a.x - b.x;
c.y = a.y - b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, float b)
{
return a + b;
@ -2510,7 +2537,19 @@ inline __device__ float update_rotary_base(
{
const float b = (scale * kv_seq_len / max_positions) - (scale - 1);
const float p = static_cast<float>(embed_dim) / (embed_dim - 2);
return base * pow(b, p);
return base * __powf(b, p);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 update_dynamic_scaling_rotary(float base, float scale, const int kv_seq_len,
const int max_positions, const int embed_dim, const bool dynamic_scaling)
{
const float b = kv_seq_len * __fdividef(scale, max_positions) - (scale - 1);
const float p = __fdividef(embed_dim, embed_dim - 2);
const float updated_base = dynamic_scaling ? base * __powf(b, p) : base;
const float updated_scale = dynamic_scaling ? 1.0f : scale;
return {updated_base, updated_scale};
}
inline __device__ void update_rotary_base_n_scale(float& base, float& scale, RotaryScalingType const scale_type,
@ -2534,8 +2573,8 @@ inline __device__ void update_rotary_base_n_scale(float& base, float& scale, Rot
inline __device__ float2 rotary_embedding_coefficient(
const int zid, const int rot_embed_dim, const float base, const float scale, const float t_step)
{
const float inv_freq = (t_step * scale) / pow(base, zid / (float) rot_embed_dim);
return {cos(inv_freq), sin(inv_freq)};
const float inv_freq = __fdividef(float(t_step * scale), __powf(base, zid / (float) rot_embed_dim));
return {__cosf(inv_freq), __sinf(inv_freq)};
}
inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
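A host-side reference of the two helpers above may help when checking numbers by hand; it is an illustration only and simply restates the device formulas (theta_z(t) = t*s / base^(z/d) for the rotary angle, and base' = base * b^p with b = s*L/L_max - (s-1) and p = d/(d-2) for NTK-style dynamic scaling):

#include <cmath>
#include <utility>

// Angle for rotary channel pair z at timestep t with linear scaling s (illustration only).
std::pair<float, float> rotaryCoefficient(int z, int rotary_dim, float base, float scale, float t)
{
    const float inv_freq = (t * scale) / std::pow(base, z / (float) rotary_dim);
    return {std::cos(inv_freq), std::sin(inv_freq)};
}

// NTK-style dynamic scaling: grow the base once the kv length passes the trained
// maximum, then treat the scale as 1 (mirrors update_dynamic_scaling_rotary above).
std::pair<float, float> dynamicScalingRotary(float base, float scale, int kv_seq_len,
                                             int max_positions, int embed_dim, bool dynamic_scaling)
{
    const float b = kv_seq_len * scale / max_positions - (scale - 1.f);
    const float p = embed_dim / (float) (embed_dim - 2);
    return {dynamic_scaling ? base * std::pow(b, p) : base, dynamic_scaling ? 1.0f : scale};
}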
@ -2629,6 +2668,30 @@ inline __device__ void apply_rotary_embedding(
k_.y = rotary_embedding_transform(k_.y, coef1);
}
inline __device__ void apply_rotary_embedding(
Float8_& q, Float8_& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
{
if (8 * tid >= rot_embed_dim)
{
return;
}
Float8_& q_ = *reinterpret_cast<Float8_*>(&q);
Float8_& k_ = *reinterpret_cast<Float8_*>(&k);
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
q_.z = rotary_embedding_transform(q_.z, coef2);
k_.z = rotary_embedding_transform(k_.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
q_.w = rotary_embedding_transform(q_.w, coef3);
k_.w = rotary_embedding_transform(k_.w, coef3);
}
inline __device__ void apply_rotary_embedding(
uint32_t& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
{
@ -2821,6 +2884,205 @@ inline __device__ void apply_rotary_embedding(
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(uint16_t& q, uint16_t q_pair, uint16_t& k, uint16_t k_pair, int tid0,
int tid1, // not used
int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
uint32_t cos = float2_to_half2(make_float2(coef.x, coef.x));
uint32_t sin = float2_to_half2(make_float2(coef.y, coef.y));
uint32_t h2, h2_pair;
reinterpret_cast<uint16_t*>(&h2)[0] = q;
reinterpret_cast<uint16_t*>(&h2)[1] = k;
reinterpret_cast<uint16_t*>(&h2_pair)[0] = q_pair;
reinterpret_cast<uint16_t*>(&h2_pair)[1] = k_pair;
if (first_half)
{
h2 = sub(mul<uint32_t>(cos, h2), mul<uint32_t>(sin, h2_pair));
}
else
{
h2 = add(mul<uint32_t>(cos, h2), mul<uint32_t>(sin, h2_pair));
}
q = reinterpret_cast<uint16_t*>(&h2)[0];
k = reinterpret_cast<uint16_t*>(&h2)[1];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t q_pair, uint32_t& k, uint32_t k_pair, int tid0,
int tid1, int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
uint32_t cos0 = float2_to_half2(make_float2(coef0.x, coef1.x));
uint32_t sin0 = float2_to_half2(make_float2(coef0.y, coef1.y));
if (first_half)
{
q = sub(mul<uint32_t>(cos0, q), mul<uint32_t>(sin0, q_pair));
k = sub(mul<uint32_t>(cos0, k), mul<uint32_t>(sin0, k_pair));
}
else
{
q = add(mul<uint32_t>(cos0, q), mul<uint32_t>(sin0, q_pair));
k = add(mul<uint32_t>(cos0, k), mul<uint32_t>(sin0, k_pair));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(__nv_bfloat16& q, __nv_bfloat16 q_pair, __nv_bfloat16& k,
__nv_bfloat16 k_pair, int tid0,
int tid1, // not used
int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
__nv_bfloat162 cos = float22bf162(make_float2(coef.x, coef.x));
__nv_bfloat162 sin = float22bf162(make_float2(coef.y, coef.y));
__nv_bfloat162 h2, h2_pair;
reinterpret_cast<__nv_bfloat16*>(&h2)[0] = q;
reinterpret_cast<__nv_bfloat16*>(&h2)[1] = k;
reinterpret_cast<__nv_bfloat16*>(&h2_pair)[0] = q_pair;
reinterpret_cast<__nv_bfloat16*>(&h2_pair)[1] = k_pair;
if (first_half)
{
h2 = sub(mul<__nv_bfloat162>(cos, h2), mul<__nv_bfloat162>(sin, h2_pair));
}
else
{
h2 = add(mul<__nv_bfloat162>(cos, h2), mul<__nv_bfloat162>(sin, h2_pair));
}
q = reinterpret_cast<__nv_bfloat16*>(&h2)[0];
k = reinterpret_cast<__nv_bfloat16*>(&h2)[1];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162 q_pair, __nv_bfloat162& k,
__nv_bfloat162 k_pair, int tid0, int tid1, int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
__nv_bfloat162 cos0 = float22bf162(make_float2(coef0.x, coef1.x));
__nv_bfloat162 sin0 = float22bf162(make_float2(coef0.y, coef1.y));
if (first_half)
{
q = sub(mul<__nv_bfloat162>(cos0, q), mul<__nv_bfloat162>(sin0, q_pair));
k = sub(mul<__nv_bfloat162>(cos0, k), mul<__nv_bfloat162>(sin0, k_pair));
}
else
{
q = add(mul<__nv_bfloat162>(cos0, q), mul<__nv_bfloat162>(sin0, q_pair));
k = add(mul<__nv_bfloat162>(cos0, k), mul<__nv_bfloat162>(sin0, k_pair));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(float& q, float q_pair, float& k, float k_pair, int tid0,
int tid1, // not used
int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
float cos = coef.x;
float sin = coef.y;
if (first_half)
{
q = sub(mul<float>(cos, q), mul<float>(sin, q_pair));
k = sub(mul<float>(cos, k), mul<float>(sin, k_pair));
}
else
{
q = add(mul<float>(cos, q), mul<float>(sin, q_pair));
k = add(mul<float>(cos, k), mul<float>(sin, k_pair));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void apply_rotary_embedding(float2& q, float2 q_pair, float2& k, float2 k_pair, int tid0, int tid1,
int rot_embed_dim, float base, float scale, int t_step, int first_half)
{
const float2 coef0 = rotary_embedding_coefficient(tid0, rot_embed_dim, base, scale, t_step);
const float2 coef1 = rotary_embedding_coefficient(tid1, rot_embed_dim, base, scale, t_step);
float2 cos0 = make_float2(coef0.x, coef1.x);
float2 sin0 = make_float2(coef0.y, coef1.y);
if (first_half)
{
q = sub(mul<float2>(cos0, q), mul<float2>(sin0, q_pair));
k = sub(mul<float2>(cos0, k), mul<float2>(sin0, k_pair));
}
else
{
q = add(mul<float2>(cos0, q), mul<float2>(sin0, q_pair));
k = add(mul<float2>(cos0, k), mul<float2>(sin0, k_pair));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Vec_type, typename Packed_type, typename T>
inline __device__ void apply_rotary_embedding_gptneox(Vec_type& q, Vec_type& k, int tidx, int rotary_embedding_dim,
float rotary_embedding_base, float rotary_embedding_scale, int t_step, bool first_half)
{
// 32 threads: each holds VEC_SIZE elements (half)
Vec_type q_pair, k_pair;
constexpr int VEC_SIZE = sizeof(Vec_type) / sizeof(Packed_type);
constexpr int PACKED_ELT_SIZE = sizeof(Packed_type) / sizeof(T);
if constexpr (sizeof(Vec_type) == 2)
{
reinterpret_cast<uint16_t&>(q_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<uint16_t&>(q), 16);
reinterpret_cast<uint16_t&>(k_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<uint16_t&>(k), 16);
}
else if constexpr (sizeof(Vec_type) == 4)
{
reinterpret_cast<unsigned int&>(q_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned int&>(q), 16);
reinterpret_cast<unsigned int&>(k_pair) = __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned int&>(k), 16);
}
else if constexpr (sizeof(Vec_type) >= 8)
{
#pragma unroll
for (int vec_id = 0; vec_id < sizeof(Vec_type) / 8; vec_id++)
{
reinterpret_cast<unsigned long*>(&q_pair)[vec_id]
= __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned long*>(&q)[vec_id], 16);
reinterpret_cast<unsigned long*>(&k_pair)[vec_id]
= __shfl_xor_sync(0xffffffff, reinterpret_cast<unsigned long*>(&k)[vec_id], 16);
}
}
const int half_rotary_dim = rotary_embedding_dim / 2;
#pragma unroll
for (int elt_id = 0; elt_id < VEC_SIZE; elt_id++)
{
// Pack two elements for the calculation (only one if the thread gets a single element)
// Assume the head size (or rotary embedding dim) is a multiple of 8.
const int rotary_emd_pos0_id
= (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 0 - int(!first_half) * half_rotary_dim)
* 2;
const int rotary_emd_pos1_id
= (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 1 - int(!first_half) * half_rotary_dim)
* 2;
const bool valid_rotary_pos = rotary_emd_pos1_id < rotary_embedding_dim;
Packed_type q_ = reinterpret_cast<Packed_type*>(&q)[elt_id];
Packed_type q_pair_ = reinterpret_cast<Packed_type*>(&q_pair)[elt_id];
Packed_type k_ = reinterpret_cast<Packed_type*>(&k)[elt_id];
Packed_type k_pair_ = reinterpret_cast<Packed_type*>(&k_pair)[elt_id];
apply_rotary_embedding(q_, q_pair_, k_, k_pair_, rotary_emd_pos0_id, rotary_emd_pos1_id, rotary_embedding_dim,
rotary_embedding_base, rotary_embedding_scale, t_step, first_half);
if (valid_rotary_pos)
{
reinterpret_cast<Packed_type*>(&q)[elt_id] = q_;
reinterpret_cast<Packed_type*>(&k)[elt_id] = k_;
}
}
}
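The function above is the GPT-NeoX (rotate-half) flavour of RoPE: each thread pulls its partner's data from the other half of the rotary dimension via __shfl_xor_sync and then applies q' = cos*q -/+ sin*q_pair depending on which half it owns. A scalar host-side reference of the same rotation, as an illustration only (no warp logic, rotary_dim assumed even):

#include <cmath>
#include <vector>

// Reference rotate-half RoPE on a single head vector (illustration only).
// x has at least rotary_dim elements; channels beyond rotary_dim stay untouched.
void rotateHalfRope(std::vector<float>& x, int rotary_dim, float base, float scale, int t_step)
{
    const int half = rotary_dim / 2;
    for (int i = 0; i < half; ++i)
    {
        // Same angle as rotary_embedding_coefficient(): theta = t*s / base^(2i/d).
        const float theta = (t_step * scale) / std::pow(base, (2.f * i) / rotary_dim);
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i];         // element in the first half
        const float x1 = x[i + half];  // its pair in the second half
        x[i]        = x0 * c - x1 * s; // first half:  cos*x - sin*pair
        x[i + half] = x1 * c + x0 * s; // second half: cos*x + sin*pair
    }
}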
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float* dst, float src)

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -204,7 +204,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
case AttentionMaskType::BIDIRECTIONALGLM:
// clang-format off
isValid = (colIdx < seqLength - 1) ||
(rowIdx == maxSeqLength - 1 && colIdx == maxSeqLength - 1);
(rowIdx == seqLength - 1 && colIdx == seqLength - 1);
// clang-format on
// seq_length==4, max_seq_len==5
// 1 1 1 1 0
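To see exactly what the changed condition fixes, the small host program below (illustration only) evaluates the old and new second clauses for the seq_length==4, max_seq_len==5 example from the comment and reports where the extra diagonal entry lands:

#include <cstdio>

int main()
{
    const int seqLength = 4, maxSeqLength = 5;
    for (int rowIdx = 0; rowIdx < maxSeqLength; ++rowIdx)
    {
        for (int colIdx = 0; colIdx < maxSeqLength; ++colIdx)
        {
            const bool oldExtra = (rowIdx == maxSeqLength - 1 && colIdx == maxSeqLength - 1);
            const bool newExtra = (rowIdx == seqLength - 1 && colIdx == seqLength - 1);
            if (oldExtra || newExtra)
            {
                std::printf("extra entry at (%d,%d): old=%d new=%d\n", rowIdx, colIdx, (int) oldExtra, (int) newExtra);
            }
        }
    }
    return 0;
}

With the old clause the extra entry sits at the padded position (4,4); with the fix it moves to (3,3), the last real token of the sequence.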

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -83,6 +83,8 @@ struct BuildDecoderInfoParams
// The kv cache capacity.
// We will apply the limited_length_causal mask when there is not enough kv cache capacity.
int attentionWindowSize;
// The number of sink tokens in the kv cache.
int sinkTokenLength;
// The number of tokens in total. It's \sum_{ii=0}^{batchSize} seqLengths[ii].
int numTokens;
// The type of attention.

Some files were not shown because too many files have changed in this diff.