Update TensorRT-LLM (20240116) (#891)

* Update TensorRT-LLM

---------

Co-authored-by: Eddie-Wang1120 <81598289+Eddie-Wang1120@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2024-01-16 20:03:11 +08:00 committed by GitHub
parent 12e82e30b0
commit c89653021e
255 changed files with 285533 additions and 25334 deletions

.github/ISSUE_TEMPLATE/bug_report.yml

@ -0,0 +1,116 @@
name: "Bug Report"
description: Submit a bug report to help us improve TensorRT-LLM
labels: [ "bug" ]
body:
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your system info with us.
placeholder: |
- CPU architecture (e.g., x86_64, aarch64)
- CPU/Host memory size (if known)
- GPU properties
- GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S)
- GPU memory size (if known)
- Clock frequencies used (if applicable)
- Libraries
- TensorRT-LLM branch or tag (e.g., main, v0.7.1)
- TensorRT-LLM commit (if known)
- Versions of TensorRT, AMMO, CUDA, cuBLAS, etc. used
- Container used (if running TensorRT-LLM in a container)
- NVIDIA driver version
- OS (Ubuntu 22.04, CentOS 7, Windows 10)
- Any other information that may be useful in reproducing the bug
validations:
required: true
- type: textarea
id: who-can-help
attributes:
label: Who can help?
description: |
To expedite the response to your issue, it would be helpful if you could identify the appropriate person
to tag using the **@** symbol. Here is a general guideline on **whom to tag**.
Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag,
you can leave it blank, and a core maintainer will make sure to involve the appropriate person.
Please tag fewer than 3 people.
Quantization: @Tracin
Documentation: @juney-nvidia
Feature request: @ncomly-nvidia
Performance: @kaiyux
Others: @byshiue
placeholder: "@Username ..."
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The tasks I am working on are:"
options:
- label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)"
- label: "My own task or dataset (give details below)"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please share a code example that demonstrates the issue you encountered. It is recommended to provide a code snippet directly.
Additionally, if you have any error messages, or stack traces related to the problem, please include them here.
Remember to use code tags to properly format your code. You can refer to the
link https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting.
Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code.
It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and code.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible."
- type: textarea
id: actual-behavior
validations:
required: true
attributes:
label: Actual behavior
description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible."
- type: textarea
id: additional-notes
validations:
required: true
attributes:
label: Additional notes
description: "Provide any additional context that you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)."


@ -15,7 +15,8 @@ repos:
rev: v4.1.0
hooks:
- id: check-added-large-files
exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/'
exclude: |
(?x)^(.*cubin.cpp)$
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key


@ -45,8 +45,6 @@ H200 is now 2.4x faster on Llama-70B with recent improvements to TensorRT-LLM GQ
- [TensorRT-LLM Overview](#tensorrt-llm-overview)
- [Installation](#installation)
- [Linux](./docs/source/installation.md)
- [Windows](windows/README.md)
- [Quick Start](#quick-start)
- [Support Matrix](#support-matrix)
- [Devices](#devices)
@ -110,10 +108,26 @@ concepts used in TensorRT-LLM, we recommend you to read the following
## Installation
*For Linux installation, see [`Linux`](./docs/source/installation.md).*
*For Windows installation, see [`Windows`](windows/README.md).*
After installing the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit),
please run the following commands to install TensorRT-LLM.
Once installed, commands to build and run LLMs must be executed from the TensorRT-LLM container.
```bash
# Obtain and start the basic docker image environment
nvidia-docker run --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
# Install dependencies; TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
# Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.
# If you want to install the stable version (corresponding to the release branch), please
# remove the `--pre` option.
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
# Check installation
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```
For users who require the best performance or debugging capabilities, please refer to the instructions for
[building from source code](docs/source/build_from_source.md).
For Windows installation, see [`Windows`](windows/README.md).
## Quick Start


@ -7,7 +7,7 @@ multiple GPUs or multiple nodes with multiple GPUs.
### 1. Build TensorRT-LLM and benchmarking source code
Please follow the [`installation document`](../../docs/source/installation.md) to build TensorRT-LLM.
Please follow the [`installation document`](../../README.md#installation) to build TensorRT-LLM.
Note that the benchmarking source code for the C++ runtime is not built by default; you can pass the `--benchmarks` argument to [`build_wheel.py`](source:scripts/build_wheel.py) to build the corresponding executable.


@ -123,6 +123,17 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};
if (session.getModelConfig().computeContextLogits())
{
generationOutput.contextLogits
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
}
if (session.getModelConfig().computeGenerationLogits())
{
generationOutput.generationLogits
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
}
TLLM_LOG_INFO(memoryCounter.toString());
for (auto r = 0; r < warmUp; ++r)
@ -175,21 +186,20 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
if (session.getModelConfig().computeContextLogits() && printAllLogits)
{
std::cout << "generationOutput.contextLogits.shape: "
<< generationOutput.contextLogitsHost->getShape()
<< generationOutput.contextLogits->getShape()
<< std::endl; // (batchsize, prompt_len, vocabsize)
std::cout << "generationOutput.contextLogits: " << *generationOutput.contextLogitsHost
<< std::endl;
std::cout << "generationOutput.contextLogits: " << *generationOutput.contextLogits << std::endl;
}
if (session.getModelConfig().computeGenerationLogits() && printAllLogits)
{
std::cout << "generationOutput.generationLogits.shape: "
<< generationOutput.generationLogitsHost->getShape()
<< generationOutput.generationLogits->getShape()
<< std::endl; // (batchsize, beamwidth, maxNewTokens, vocabsize)
generationOutput.generationLogitsHost->reshape(ITensor::makeShape({batchSize * beamWidth,
generationOutput.generationLogits->reshape(ITensor::makeShape({batchSize * beamWidth,
maxNewTokens, modelConfig.getVocabSizePadded(worldConfig.getSize())}));
std::cout << "generationOutput.generationLogits: " << *generationOutput.generationLogitsHost
std::cout << "generationOutput.generationLogits: " << *generationOutput.generationLogits
<< std::endl;
}
}


@ -17,7 +17,6 @@ from argparse import ArgumentParser
# isort: off
import torch
import tensorrt as trt
# isort: on
from cuda import cuda, cudart
from mpi4py import MPI
@ -25,8 +24,10 @@ from polygraphy.backend.trt import CreateConfig, EngineFromNetwork
import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm._ipc_utils import IpcMemory, peer_access
from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import (current_all_reduce_helper,
init_all_reduce_helper)
def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
@ -42,25 +43,17 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
if world_size == 1:
raise RuntimeError("Benchmark must run with mpi_world_size > 1")
ipc_barriers_in = IpcMemory(
mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
ipc_barriers_out = IpcMemory(
mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
torch_dtype = tllm._utils.str_dtype_to_torch(dtype)
min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
inner_loop = 1000
size = min_size
dtype_size = torch.finfo(torch_dtype).bits // 8
init_all_reduce_helper()
while size < max_size:
ipc_buffers = IpcMemory(mapping, size * 4)
workspace = torch.tensor(ipc_buffers.serialize() +
ipc_barriers_in.serialize() +
ipc_barriers_out.serialize(),
dtype=torch.int64,
device="cpu")
input = torch.zeros(size, dtype=torch_dtype, device="cuda")
_buffers, workspace = current_all_reduce_helper().allocate_workspace(
mapping, size * dtype_size)
input = torch.ones(size, dtype=torch_dtype, device="cuda")
for strategy in [
AllReduceStrategy.RING, AllReduceStrategy.ONESHOT,
@ -77,16 +70,11 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
shape=input.shape,
dtype=tllm.str_dtype_to_trt(dtype))
w = Tensor(name='workspace',
shape=workspace.shape,
dtype=trt.int64)
current_all_reduce_helper().set_workspace_tensor(mapping)
current = x
for i in range(inner_loop):
current = allreduce(
current, mapping.tp_group,
w if strategy != AllReduceStrategy.RING else None, i,
strategy)
for _ in range(inner_loop):
current = allreduce(current, mapping.tp_group, strategy)
output = current.trt_tensor
output.name = 'output'
@ -104,7 +92,7 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
output = torch.zeros_like(input)
stream = torch.cuda.current_stream()
feed_dict = {'x': input, 'workspace': workspace}
feed_dict = {'x': input, 'all_reduce_workspace': workspace}
session = tllm.runtime.Session.from_engine(build_engine())
_, start = cuda.cuEventCreate(0)
@ -119,9 +107,11 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
assert torch.allclose(output, (input * world_size)**inner_loop)
if mapping.rank == 0:
print(f"{size=}, {strategy=}, {ms=}")
size *= ratio
if mapping.rank == 0:
print("")


@ -197,8 +197,8 @@ _allowed_configs = {
builder_opt=None,
quantization="int8_sq_per_token_channel",
)),
"gpt-next_2b":
ModelConfig(name="gpt-next_2b",
"gpt_next_2b":
ModelConfig(name="gpt_next_2b",
family="gpt",
benchmark_type="gpt",
build_config=BuildConfig(
@ -305,25 +305,6 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
)),
"llama_7b_moe":
ModelConfig(name="llama_7b_moe",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=32000,
hidden_act='silu',
n_positions=2048,
inter_size=11008,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
moe_num_experts=4,
moe_top_k=1,
)),
"llama_13b":
ModelConfig(name="llama_13b",
family="llama",
@ -427,6 +408,25 @@ _allowed_configs = {
max_output_len=200,
builder_opt=None,
quantization="int8_sq_per_tensor")),
"mixtral_8x7b":
ModelConfig(name="mixtral_8x7b",
family="llama",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
num_heads=32,
hidden_size=4096,
vocab_size=32000,
hidden_act='swiglu',
n_positions=2048,
inter_size=14336,
max_batch_size=128,
max_input_len=512,
max_output_len=200,
builder_opt=None,
moe_num_experts=8,
moe_top_k=2,
)),
"gptj_6b":
ModelConfig(name="gptj_6b",
family="gptj",


@ -251,7 +251,6 @@ def main(args):
from bert_benchmark import BERTBenchmark
from enc_dec_benchmark import EncDecBenchmark
from gpt_benchmark import GPTBenchmark
from mem_monitor import MemoryMonitor
import tensorrt_llm
from tensorrt_llm.logger import logger
@ -281,6 +280,13 @@ def main(args):
rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()
# TODO: Re-enable memory monitor for multi-gpu benchmarks.
# The current memory monitor causes the benchmark script to hang
# because MPI does not work well with multiprocessing.
disable_mem_monitor = world_size > 1
if not disable_mem_monitor:
from mem_monitor import MemoryMonitor
benchmark_profiler = None
if args.model in get_allowed_models(benchmark_type="gpt"):
benchmark_profiler = BenchmarkProfiler()
@ -314,8 +320,9 @@ def main(args):
torch.cuda.empty_cache()
latencies = []
memory_monitor = MemoryMonitor()
memory_monitor.start()
if not disable_mem_monitor:
memory_monitor = MemoryMonitor()
memory_monitor.start()
iter_idx = 0
try:
@ -346,12 +353,17 @@ def main(args):
except Exception as e:
print("Found exception during benchmarking", e.with_traceback())
memory_monitor.kill()
if not disable_mem_monitor:
memory_monitor.kill()
raise e
memory_monitor.stop()
_, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
peak_gpu_used = round(peak_gpu_used, 3)
if not disable_mem_monitor:
memory_monitor.stop()
_, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
peak_gpu_used = round(peak_gpu_used, 3)
else:
peak_gpu_used = 0.0
if benchmark_profiler is not None:
benchmark_profiler.add_aux_info('iter_count', iter_idx)
benchmark_profiler.stop()


@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Event, Process, Queue
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
device_memory_info)
device_memory_info, host_memory_info)
class MemoryMonitor:
@ -28,6 +29,7 @@ class MemoryMonitor:
self._peak_host_memory = 0
self._peak_device_memory = 0
self.pid = os.getpid()
self.device_handles = {}
self.signal_event = Event() # Sending signal to subprocess
@ -49,19 +51,28 @@ class MemoryMonitor:
self.signal_event.set()
logger.debug("Sent signal to stop memory monitor subprocess.")
peak_mem_use = self.peak_mem_queue.get()
self._peak_host_memory = max(self._peak_host_memory, peak_mem_use[0])
self._peak_device_memory = max(self._peak_device_memory,
self.peak_mem_queue.get())
peak_mem_use[1])
self.mem_monitor_process.join()
self.mem_monitor_process = None
logger.debug("Memory monitor subprocess joined.")
def _upd_peak_memory_usage(self, signal_event, peak_mem_queue):
peak_used, _, _ = device_memory_info()
peak_host_used, peak_device_used = self.get_memory_usage()
while not signal_event.is_set():
used, _, _ = device_memory_info()
peak_used = max(used, peak_used)
peak_mem_queue.put(peak_used)
host_used, device_used = self.get_memory_usage()
peak_host_used = max(host_used, peak_host_used)
peak_device_used = max(device_used, peak_device_used)
peak_mem_queue.put((peak_host_used, peak_device_used))
def get_memory_usage(self):
host_used, _, _ = host_memory_info(self.pid)
device_used, _, _ = device_memory_info()
return host_used, device_used
def get_peak_memory_usage(self, unit: MemUnitType = 'GiB'):
return bytes_to_target_unit(self._peak_host_memory, unit), \


@ -29,7 +29,7 @@ project(tensorrt_llm LANGUAGES CXX)
# Build options
option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
OFF)
ON)
option(BUILD_TESTS "Build Google tests" ON)
option(BUILD_BENCHMARKS "Build benchmarks" ON)
option(NVTX_DISABLE "Disable all NVTX features" ON)


@ -33,7 +33,7 @@ public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = false,
std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true,
bool logIterationData = false)
: kvCacheConfig{kvCacheConfig}
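The default for `enableTrtOverlap` flips from `true` to `false` in this hunk, so overlap scheduling now has to be requested explicitly. A minimal sketch of opting back in (the include and enclosing namespace are assumptions, not shown in this diff):

```cpp
// Sketch only: assumes this code lives in the same namespace as TrtGptModelOptionalParams;
// the include below is illustrative and not taken from this diff.
#include <optional>

TrtGptModelOptionalParams makeOverlapParams()
{
    // Arguments follow the constructor shown above; overlap must now be requested
    // explicitly because its default changed from true to false.
    return TrtGptModelOptionalParams(KvCacheConfig{}, std::nullopt, /*enableTrtOverlap=*/true);
}
```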


@ -39,8 +39,6 @@ public:
{
TLLM_CHECK_WITH_INFO(static_cast<bool>(this->ids), "Invalid ids tensor");
TLLM_CHECK_WITH_INFO(static_cast<bool>(this->lengths), "Invalid lengths tensor");
generationLogitsFragments = std::make_shared<std::vector<TensorPtr>>();
}
// mandatory parameters
@ -53,11 +51,6 @@ public:
TensorPtr contextLogits; // [batch_size, max_input_length, vocab_size_padded], if packed, the shape will be
// [packed_size, vocab_size_padded]
TensorPtr generationLogits; // [batch_size, beam_width, max_output_length, vocab_size_padded]
// generation logit pointer list
std::shared_ptr<std::vector<TensorPtr>> generationLogitsFragments;
TensorPtr contextLogitsHost;
TensorPtr generationLogitsHost;
// callbacks
Callback onTokenGenerated;


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32eb4ace3a218b307f4a8ade8b2f47540d690118539c75bbff82ee251e098f3c
size 1927834
oid sha256:85691c252e1025d087c3aeb34fa249ea2424b079b883c28228ec0f838704ee0b
size 1920092


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dcf82d3f89bfd45abffbdd0856bd18d0c80154397fedd4617f65efca8c5619d5
size 1938574
oid sha256:5ac6ca64007c016dff30aadb3f167b263c31aef995a66ac96d731d2c077a3535
size 1930918


@ -1,3 +1,3 @@
0ad3a2c87135a42a7c766472e3690945 libtensorrt_llm_batch_manager_static.a
e83f9560653b8a7bf0e5465fed610422 libtensorrt_llm_batch_manager_static.pre_cxx11.a
87230159783a783af845d5fbc908d47108fbb754 commit
435c2a649cca52168e243ff67b7e4811 libtensorrt_llm_batch_manager_static.a
a6c17eee9cb72f3442482656a0f7307c libtensorrt_llm_batch_manager_static.pre_cxx11.a
d086d9953961e0cf567a4f6185d929df4c5bf9ec commit


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00627c7c18991cf9d832c63d4df4174744b4bed2d3a393051db5523f36a6a16b
size 1867794
oid sha256:52183471dbd6db61eff4cc09f78f56e8ce48f624947aa81806f7fb7585dd2617
size 1863048


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d06f1cc77ea915b9494f6362a8229dbc6d4728711b25c7dab813869a75b599e
size 1847520
oid sha256:a63a46d5744b119398678983e252b2a142d2005894643650eb698cb98e5362cd
size 1838410


@ -1,2 +1,2 @@
1424807fe5927445d2daf3a3d9db19cc libtensorrt_llm_batch_manager_static.a
afc1bba58d7138cfb018790b62a92990 libtensorrt_llm_batch_manager_static.pre_cxx11.a
591e69f5ed2dd632ef62dc57900894db libtensorrt_llm_batch_manager_static.a
1733f8c663e9c0c1969e74f6a741dbb8 libtensorrt_llm_batch_manager_static.pre_cxx11.a


@ -200,15 +200,6 @@ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return bf16hadd2(x, y);
};
inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
@ -301,3 +292,22 @@ inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, _
} // namespace common
} // namespace tensorrt_llm
// Operator definitions intentionally in global namespace
namespace
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hmul2(x, y);
};
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
return tensorrt_llm::common::bf16hadd2(x, y);
};
#endif
#endif
} // namespace


@ -0,0 +1,40 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
namespace tensorrt_llm::utils::customAllReduceUtils
{
constexpr size_t NUM_POINTERS_PER_RANK = 4;
namespace
{
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
if (worldSize <= 2)
{
return 16 * 1000 * 1000;
}
return 8 * 1000 * 1000;
}
} // namespace
} // namespace tensorrt_llm::utils::customAllReduceUtils
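A small, self-contained sketch (the include path is an assumption) that simply exercises the helper above; the printed values follow directly from the branch inside `getMaxRequiredWorkspaceSize`:

```cpp
// Sketch only: the include path is assumed; the behavior mirrors the header above.
#include <cstdio>

#include "tensorrt_llm/utils/customAllReduceUtils.h" // assumed location of the new header

int main()
{
    namespace carUtils = tensorrt_llm::utils::customAllReduceUtils;
    for (int worldSize : {2, 4, 8})
    {
        // 16,000,000 bytes when worldSize <= 2, otherwise 8,000,000 bytes, per the header above.
        std::printf("TP=%d -> %zu bytes (pointers per rank: %zu)\n", worldSize,
            carUtils::getMaxRequiredWorkspaceSize(worldSize), carUtils::NUM_POINTERS_PER_RANK);
    }
    return 0;
}
```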


@ -22,6 +22,21 @@
namespace tensorrt_llm::common
{
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
static bool forceXQA = false;
if (force_xqa_env_var != nullptr)
{
if (force_xqa_env_var[0] == '1' && force_xqa_env_var[1] == '\0')
{
forceXQA = true;
}
}
return forceXQA;
}
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{


@ -20,6 +20,9 @@
namespace tensorrt_llm::common
{
// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();
// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();
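A minimal sketch (the include path and the surrounding dispatch logic are assumptions) of how a kernel-selection decision could consult the new override; per the definition above, `forceXQAKernels()` returns true only when `TRTLLM_FORCE_XQA=1` is set in the environment:

```cpp
// Sketch only: include path and dispatch policy are illustrative, not from this diff.
#include "tensorrt_llm/common/envUtils.h" // assumed header declaring forceXQAKernels()

bool shouldUseXqaKernel(bool heuristicPrefersXqa)
{
    // The TRTLLM_FORCE_XQA=1 override takes precedence over the built-in heuristic.
    return tensorrt_llm::common::forceXQAKernels() || heuristicPrefersXqa;
}
```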


@ -860,5 +860,47 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_
template bool invokeCheckRange<int>(
const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
/*
* Determine the total workspace size based on a vector containing multiple variable sizes.
*/
size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTES)
{
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
size_t total = 0;
for (auto sz : sizes)
{
total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
}
// We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is
// not aligned.
return total + ALIGN_BYTES - 1;
}
/*
* Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses
* of each variable.
*/
void calcAlignedPointers(
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES)
{
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2
assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);
// In case the start address is not aligned
char* ptr = reinterpret_cast<char*>((reinterpret_cast<size_t>(p) + ALIGN_BYTES - 1) & ALIGN_MASK);
outPtrs.reserve(sizes.size());
for (auto sz : sizes)
{
outPtrs.push_back(ptr);
ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
}
}
} // namespace common
} // namespace tensorrt_llm


@ -258,5 +258,8 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T>
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);
size_t calcAlignedSize(const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
void calcAlignedPointers(
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
} // namespace common
} // namespace tensorrt_llm
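A minimal sketch (the include path, buffer sizes, and the host allocation are illustrative) showing how the two new helpers pair up: size one workspace for several buffers, then carve out 256-byte-aligned sub-pointers inside it:

```cpp
// Sketch only: the include path is assumed, and std::malloc stands in for a real
// (typically device) allocation; the helpers are the ones declared above.
#include <cstdlib>
#include <vector>

#include "tensorrt_llm/common/memoryUtils.h" // assumed location of calcAlignedSize/calcAlignedPointers

void setupWorkspace()
{
    using tensorrt_llm::common::calcAlignedPointers;
    using tensorrt_llm::common::calcAlignedSize;

    std::vector<size_t> sizes = {1024, 4096, 333}; // per-buffer byte counts (illustrative)
    size_t total = calcAlignedSize(sizes);         // each size padded to 256 B, plus alignment slack

    void* workspace = std::malloc(total);          // stand-in for the real allocation

    std::vector<void*> ptrs;
    calcAlignedPointers(ptrs, workspace, sizes);   // ptrs[i] is the 256-B-aligned start of buffer i

    std::free(workspace);
}
```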


@ -0,0 +1,542 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk,
int64_t* splitk_buffer_offsets)
{
// in_tensor: [problem_idx, k_partition, hidden_size]
// Note that different requests of in_tensor might have different hidden_size (=m*n)
// so we need to use splitk_buffer_offsets.
// out_tensor: problem_idx * [hidden_size]
const int problem_idx = blockIdx.y;
GemmCoord problem = problem_sizes[problem_idx];
const int hidden_size = problem.m() * problem.n();
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
T_OUT* out_tensor_ = out_tensor[problem_idx];
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
{
float sum = 0.0f;
for (int k_idx = 0; k_idx < splitk; k_idx++)
{
sum += (float) in_tensor_[k_idx * hidden_size + i];
}
out_tensor_[i] = (T_OUT) (sum);
}
}
/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
using BaseKernel = BaseKernel_;
using ElementA = typename BaseKernel::ElementA;
using LayoutA = typename BaseKernel::LayoutA;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
static ComplexTransform const kTransformA = BaseKernel::kTransformA;
static int const kAlignmentA = BaseKernel::kAlignmentA;
using ElementB = typename BaseKernel::ElementB;
using LayoutB = typename BaseKernel::LayoutB;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
static ComplexTransform const kTransformB = BaseKernel::kTransformB;
static int const kAlignmentB = BaseKernel::kAlignmentB;
using ElementC = typename BaseKernel::ElementC;
using LayoutC = typename BaseKernel::LayoutC;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
static int const kAlignmentC = BaseKernel::kAlignmentC;
using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
using Operator = typename BaseKernel::Operator;
using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
using MathOperator = typename WarpMmaOperator::MathOperator;
using OperatorClass = typename WarpMmaOperator::OperatorClass;
using ArchTag = typename WarpMmaOperator::ArchTag;
using ThreadblockShape = typename BaseKernel::Mma::Shape;
using WarpShape = typename BaseKernel::WarpShape;
using InstructionShape = typename BaseKernel::InstructionShape;
static int const kStages = BaseKernel::Mma::kStages;
/// Argument structure
using Arguments = typename BaseKernel::Arguments;
using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
/// Kernel parameters object
typename BaseKernel::Params gemm_params_;
private:
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count)
{
int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i)
{
cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
tiles += problem_tile_count(problem);
}
return tiles;
}
/// Copy from `data` to `workspace`
Status copy_to_workspace(void* workspace, void* data, size_t bytes)
{
cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
if (cuda_error != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
cuda_error = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
return Status::kErrorInternal;
}
return Status::kSuccess;
}
/// Precomputes scheduling information for the grouped GEMM
Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
{
size_t workspace_bytes = get_workspace_size(args);
std::vector<uint8_t> host_workspace(workspace_bytes);
BaseKernel::ProblemVisitor::host_precompute(
args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
}
/// Reorder `data` according to `indices`
template <typename T>
static void reorder_array(T* data, const std::vector<size_t>& indices)
{
// For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size());
for (size_t i = 0; i < indices.size(); ++i)
{
copy.at(i) = data[indices[i]];
}
memcpy(data, copy.data(), indices.size() * sizeof(T));
}
public:
/// Constructs the GEMM.
BaseSplitkGrouped() {}
/// Determines whether the GEMM can execute the given problem.
static Status can_implement(Arguments const& args)
{
return BaseKernel::can_implement(args);
}
/// Get the number of tiles in a problem
static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
{
auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
return BaseKernel::ProblemVisitor::tile_count(grid);
}
/// Get the number of tiles across all problems in a group
static int32_t group_tile_count(Arguments const& args)
{
if (args.host_problem_sizes == nullptr)
{
CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
return -1;
}
return group_tile_count(args.host_problem_sizes, args.problem_count);
}
/// Gets the workspace size
static size_t get_workspace_size(Arguments const& args)
{
size_t total_mn = 0;
for (int i = 0; i < args.problem_count; i++)
{
total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
}
size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
args.host_problem_sizes, args.problem_count, args.threadblock_count);
}
return workSpaceSize;
}
/// Computes the grid shape
static dim3 get_grid_shape(Arguments const& args)
{
return dim3(args.threadblock_count, 1, 1);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int smem_capacity = -1)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
cudaError_t result;
if (smem_size > (48 << 10))
{
result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
return -1;
}
}
int max_active_blocks = -1;
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Sorts each pointer passed in according to the indices that sort
/// `problem_sizes_ptr` in descending order of problem-K dimension.
static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
{
std::vector<size_t> indices(problem_count);
std::iota(indices.begin(), indices.end(), 0);
std::stable_sort(indices.begin(), indices.end(),
[&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });
reorder_array(problem_sizes_ptr, indices);
reorder_array(lda_host_ptr, indices);
reorder_array(ldb_host_ptr, indices);
reorder_array(ldc_host_ptr, indices);
reorder_array(ldd_host_ptr, indices);
reorder_array(offset_A_ptr, indices);
reorder_array(offset_B_ptr, indices);
reorder_array(offset_C_ptr, indices);
reorder_array(offset_D_ptr, indices);
}
/// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient(
const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
{
// Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy.
int device_idx;
cudaError_t result = cudaGetDevice(&device_idx);
if (result != cudaSuccess)
{
// Call cudaGetLastError() to clear the error bit
result = cudaGetLastError();
CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
return 0;
}
int multiprocessor_count;
result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
return 0;
}
bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
if (override_sm_count)
{
available_sm_count = multiprocessor_count;
}
int max_active_blocks = maximum_active_blocks();
if (max_active_blocks <= 0)
{
return 0;
}
int occupancy_based_block_count = available_sm_count * max_active_blocks;
if (problem_sizes_ptr == nullptr || problem_count == 0)
{
return occupancy_based_block_count;
}
int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
// If the group contains a single problem, launching the exact number of
// threadblocks needed to cover the problem minimizes the work performed
// per threadblock in finding the next tile to compute. We return total_tiles
// unless the user has provided the SM count.
if (problem_count == 1 && override_sm_count)
{
return total_tiles;
}
// Choose between the full wave of threadblocks and the tile count. If there
// are fewer tiles in the group than threadblocks in the full wave, only
// some threadblocks will be assigned tiles. Those threadblocks
// which are not assigned tiles still need to perform the work of iterating through
// problem sizes to determine that they have no work to do. This competes for cycles
// with those threadblocks that are assigned tiles to compute.
return std::min(total_tiles, occupancy_based_block_count);
}
/// Initializes GEMM state from arguments.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Workspace
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
}
else
{
gemm_params_ = typename BaseKernel::Params(args, workspace);
}
// Specify shared memory capacity for kernel.
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
if (smem_size >= (48 << 10))
{
cudaError_t result
= cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
if (result != cudaSuccess)
{
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Lightweight update given a subset of arguments
Status update(Arguments const& args, void* workspace = nullptr)
{
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes && !workspace)
{
return Status::kErrorWorkspaceNull;
}
if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
{
int32_t tile_count = group_tile_count(args);
Status status = precompute(args, tile_count, workspace);
if (status != Status::kSuccess)
{
return status;
}
gemm_params_.update(args, workspace, tile_count);
}
else
{
gemm_params_.update(args, workspace);
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status run(cudaStream_t stream = nullptr)
{
if (!gemm_params_.problem_visitor.problem_count)
{
return Status::kSuccess;
}
//
// Launch kernel
//
// Launch splitk grouped gemm
{
dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
dim3 block(BaseKernel::kThreadCount, 1, 1);
int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Launch splitkReduction
{
dim3 grid(32, gemm_params_.problem_visitor.problem_count);
dim3 block(256);
splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
gemm_params_.splitk_buffer_offsets);
cudaError_t result = cudaGetLastError();
if (result != cudaSuccess)
{
CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
return Status::kSuccess;
}
/// Runs the kernel using initialized state.
Status operator()(cudaStream_t stream = nullptr)
{
return run(stream);
}
/// Initializes and runs the kernel.
Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
{
Status status = initialize(args, workspace, stream);
if (status == Status::kSuccess)
{
status = run(stream);
}
return status;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
using GemmKernel = GemmKernel_;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace device
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,207 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/permute.h"
#include "splitk_gemm_grouped.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Complex elementwise transformation on A operand
ComplexTransform TransformA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Complex elementwise transformation on B operand
ComplexTransform TransformB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
/// Operation performed by GEMM
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
ElementC_, ElementAccumulator>::Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Permute result D
typename PermuteDLayout = layout::NoPermute,
///
typename Enable = void>
struct DefaultSplitkGemmGrouped;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Real-valued GEMM kernels
//
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for C and D matrix operands
typename ElementC,
/// Layout type for C and D matrix operands
typename LayoutC,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Warp-level tile size (concept: GemmShape)
typename InstructionShape,
/// Epilogue output operator
typename EpilogueOutputOp,
/// Threadblock-level swizzling operator
typename ThreadblockSwizzle,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Whether the schedule of problems to visit has been precomputed
GroupScheduleMode GroupScheduleMode_,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Permute result D
typename PermuteDLayout>
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
ComplexTransform::kNone, // transform A
kAlignmentA, ElementB, LayoutB,
ComplexTransform::kNone, // transform B
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
{
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
// Define the default GEMM kernel
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
false, /*GatherB*/
false, /*ScatterD*/
PermuteDLayout>::GemmKernel;
/// Define the kernel in terms of the default kernel
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////


@ -0,0 +1,494 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace kernel
{
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
bool Transposed = false>
struct SplitkGemmGrouped
{
public:
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
static bool const kTransposed = Transposed;
// Optional transpose
using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
kTransposed>;
// Public-facing type definitions related to operand element type, layout, and complex conjugate
// operation. Must interact with the 'kTransposed' notion.
using ElementA = typename MapArguments::ElementA;
using LayoutA = typename MapArguments::LayoutA;
using ElementB = typename MapArguments::ElementB;
using LayoutB = typename MapArguments::LayoutB;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename MapArguments::LayoutC;
using ElementFinalOutput = typename MapArguments::ElementA;
static ComplexTransform const kTransformA = MapArguments::kTransformA;
static ComplexTransform const kTransformB = MapArguments::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = MapArguments::kAlignmentA;
static int const kAlignmentB = MapArguments::kAlignmentB;
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
using ProblemVisitor
= GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;
//
// Structures
//
/// Argument structure
struct Arguments
{
//
// Data members
//
GemmCoord* problem_sizes;
int problem_count;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
// Only used by device-level operator
GemmCoord* host_problem_sizes;
// splitK
int split_k_slices;
int64_t* splitk_buffer_offsets;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
Arguments()
: problem_count(0)
, threadblock_count(0)
, ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
/// Ctor
CUTLASS_HOST_DEVICE
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
int64_t* splitk_buffer_offsets)
: problem_sizes(problem_sizes)
, problem_count(problem_count)
, threadblock_count(threadblock_count)
, output_op(output_op)
, ptr_A(ptr_A)
, ptr_B(ptr_B)
, ptr_C(ptr_C)
, ptr_D(ptr_D)
, lda(lda)
, ldb(ldb)
, ldc(ldc)
, ldd(ldd)
, host_problem_sizes(host_problem_sizes)
, split_k_slices(split_k_slices)
, splitk_buffer_offsets(splitk_buffer_offsets)
{
}
};
//
// Structure for precomputing values in host memory and passing to kernels
//
/// Parameters structure
struct Params
{
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
typename EpilogueOutputOp::Params output_op;
ElementA** ptr_A;
ElementB** ptr_B;
ElementFinalOutput** ptr_C;
ElementFinalOutput** ptr_D;
ElementC* ptr_C_split;
ElementC* ptr_D_split;
typename LayoutA::Stride::LongIndex* lda;
typename LayoutB::Stride::LongIndex* ldb;
typename LayoutC::Stride::LongIndex* ldc;
typename LayoutC::Stride::LongIndex* ldd;
// splitK
GemmCoord grid_tiled_shape;
int swizzle_log_tile;
int gemm_k_size;
GemmCoord* host_problem_sizes;
int split_k_slices;
int64_t* splitk_buffer_offsets;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params()
: ptr_A(nullptr)
, ptr_B(nullptr)
, ptr_C(nullptr)
, ptr_D(nullptr)
, ptr_C_split(nullptr)
, ptr_D_split(nullptr)
, lda(nullptr)
, ldb(nullptr)
, ldc(nullptr)
, ldd(nullptr)
, swizzle_log_tile(0)
, gemm_k_size(0)
, host_problem_sizes(nullptr)
, split_k_slices(1)
, splitk_buffer_offsets(nullptr)
{
}
CUTLASS_HOST_DEVICE
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
: problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
, host_problem_sizes(args.host_problem_sizes)
, threadblock_count(args.threadblock_count)
, output_op(args.output_op)
, ptr_A(args.ptr_A)
, ptr_B(args.ptr_B)
, ptr_C(args.ptr_C)
, ptr_D(args.ptr_D)
, ptr_C_split((ElementC*) workspace)
, ptr_D_split((ElementC*) workspace)
, lda(args.lda)
, ldb(args.ldb)
, ldc(args.ldc)
, ldd(args.ldd)
, split_k_slices(args.split_k_slices)
, splitk_buffer_offsets(args.splitk_buffer_offsets)
{
// Determine grid shape
ThreadblockSwizzle threadblock_swizzle;
grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);
// Only supports grouped problems that all share the same K dimension.
int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();
gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
}
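// Worked example (illustrative numbers, not taken from this file): with
// host_problem_sizes[0] = {M, N, K} = {128, 256, 4096}, ThreadblockShape::kK == 64 and
// split_k_slices == 2, grid_tiled_shape.k() == 2, full_gemm_k_iterations == 4096 / 64 == 64,
// gemm_k_iterations == 64 / 2 == 32, so each split-K slice covers gemm_k_size == 32 * 64 == 2048
// columns of K.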
CUTLASS_HOST_DEVICE
void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
{
problem_visitor =
typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
threadblock_count = args.threadblock_count;
output_op = args.output_op;
ptr_A = args.ptr_A;
ptr_B = args.ptr_B;
ptr_C = args.ptr_C;
ptr_D = args.ptr_D;
ptr_C_split = static_cast<ElementC*>(workspace);
ptr_D_split = static_cast<ElementC*>(workspace);
lda = args.lda;
ldb = args.ldb;
ldc = args.ldc;
ldd = args.ldd;
}
};
/// Shared memory storage structure
struct SharedStorage
{
union
{
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
} kernel;
// ProblemVisitor shared storage can't be overlapped with others
typename ProblemVisitor::SharedStorage problem_visitor;
};
public:
//
// Methods
//
CUTLASS_DEVICE
SplitkGemmGrouped() {}
/// Determines whether kernel satisfies alignment
static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
{
return Status::kSuccess;
}
static Status can_implement(Arguments const& args)
{
return Status::kSuccess;
}
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
// Outer 'persistent' loop to iterate over tiles
while (problem_visitor.next_tile())
{
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
// Load element pointers. Exchange pointers and strides if working on the transposed problem
ElementA* ptr_A
= reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);
ElementB* ptr_B
= reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
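// Illustrative mapping (assumed tile sizes): with Mma::Shape = <128, 128, 64> and
// grid_shape.n() == 4, threadblock_idx == 6 maps to tile row 6 / 4 == 1 and tile column
// 6 % 4 == 2, i.e. threadblock_offset == (128, 256, 0); the K offset of this split-K slice
// comes from threadblock_tile_offset.k() * gemm_k_size.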
// Problem size is a function of threadblock index in the K dimension
int problem_size_k;
if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
{
problem_size_k = problem_size.k();
}
else
{
problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
}
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = canonical_warp_idx_sync();
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = params.ptr_C_split;
ElementC* ptr_D = params.ptr_D_split;
LayoutC layout_C(params.ldc[problem_idx]);
LayoutC layout_D(params.ldd[problem_idx]);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// assume identity swizzle
MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
+ gridDim.z * params.splitk_buffer_offsets[problem_idx]);
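// Illustrative offset (assumed values): for a 128 x 256 problem, the slice with
// threadblock_tile_offset.k() == 1 writes its partial sums starting at element
// 128 * 256 * 1 + gridDim.z * splitk_buffer_offsets[problem_idx] of the split-K workspace,
// so every K slice of every problem lands in a disjoint region before the final reduction.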
Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -15,36 +15,18 @@
*/
#include "customAllReduceKernels.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include <tuple>
namespace tensorrt_llm::kernels
{
using tensorrt_llm::common::hadd2;
using tensorrt_llm::common::datatype_enum;
using tensorrt_llm::common::divUp;
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t myHadd2(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
@ -69,80 +51,62 @@ static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_add
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type converter that packs the data format into a 128-bit data type
template <typename T>
struct ARTypeConverter
{
using Type = uint4;
};
using PackedFloat = union
{
int4 packed;
float unpacked[4];
};
using PackedHalf = union
{
int4 packed;
half2 unpacked[4];
};
template <typename T>
struct PackedOn16Bytes
{
};
template <>
struct PackedOn16Bytes<float>
{
using Type = PackedFloat;
};
template <>
struct PackedOn16Bytes<half>
{
using Type = PackedHalf;
};
#ifdef ENABLE_BF16
template <>
struct ARTypeConverter<__nv_bfloat16>
{
using Type = bf168;
};
using PackedBFloat16 = union
{
int4 packed;
__nv_bfloat162 unpacked[4];
};
template <>
struct PackedOn16Bytes<__nv_bfloat16>
{
using Type = PackedBFloat16;
};
#endif
// Add two 128-bit packed values
template <typename T_IN, typename T_COMP>
inline __device__ T_IN add128b(T_IN a, T_IN b);
template <>
inline __device__ uint4 add128b<uint4, uint16_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = myHadd2(a.x, b.x);
c.y = myHadd2(a.y, b.y);
c.z = myHadd2(a.z, b.z);
c.w = myHadd2(a.w, b.w);
return c;
}
template <>
inline __device__ uint4 add128b<uint4, uint32_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = fadd(a.x, b.x);
c.y = fadd(a.y, b.y);
c.z = fadd(a.z, b.z);
c.w = fadd(a.w, b.w);
return c;
}
#ifdef ENABLE_BF16
template <>
inline __device__ bf168 add128b<bf168, __nv_bfloat16>(bf168 a, bf168 b)
{
bf168 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
#endif
// init 128bits data with 0
template <typename T>
inline __device__ T init_packed_type();
template <>
inline __device__ uint4 init_packed_type()
{
return make_uint4(0u, 0u, 0u, 0u);
}
template <typename T>
inline __device__ int4 add128b(T& a, T& b)
{
T c;
c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
return c.packed;
}
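// Minimal usage sketch (assumed, for illustration only): summing two 16-byte chunks of half
// data through the packed unions defined above. src_a, src_b and dst are hypothetical names.
//
//   using Packed = PackedOn16Bytes<half>::Type;        // PackedHalf
//   Packed a, b;
//   a.packed = *reinterpret_cast<const int4*>(src_a);  // 8 halves per 128-bit load
//   b.packed = *reinterpret_cast<const int4*>(src_b);
//   int4 sum = add128b(a, b);                          // four element-wise half2 additions
//   *reinterpret_cast<int4*>(dst) = sum;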
#ifdef ENABLE_BF16
template <>
inline __device__ bf168 init_packed_type()
{
bf168 val;
uint4& val_u = reinterpret_cast<uint4&>(val);
val_u = make_uint4(0u, 0u, 0u, 0u);
return val;
}
#endif
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx)
{
@ -179,10 +143,10 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
using PackedStruct = typename PackedOn16Bytes<T>::Type;
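// For example (16-byte vectorized accesses): NUM_ELTS == 8 for half / __nv_bfloat16 and
// 4 for float, so each thread moves one 128-bit int4 per iteration regardless of T.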
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
@ -204,23 +168,24 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
{
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][iter_offset])[0];
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][iter_offset]);
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
PackedStruct sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
sums = add128b<PackedType, T>(sums, vals[ii]);
sums.packed = add128b(sums, vals[ii]);
}
// Store to the destination buffer.
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset])[0] = sums;
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
}
}
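// In the one-shot variant above, every rank reads all peers' full buffers and reduces them
// locally. The two-shot variant below instead reduces only a per-rank slice, writes the partial
// result back into the shared buffer, and then gathers the reduced slices from the other ranks
// (effectively a reduce-scatter followed by an all-gather).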
@ -234,10 +199,10 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
using PackedType = typename PackedOn16Bytes<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
@ -268,19 +233,20 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][local_offset])[0];
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][local_offset]);
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
PackedType sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
sums = add128b<PackedType, T>(sums, vals[ii]);
sums.packed = add128b(sums, vals[ii]);
}
// Store to the local buffer.
reinterpret_cast<PackedType*>(&src_d[0][local_offset])[0] = sums;
*reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
}
// sync threads to make sure all block threads have the sums
@ -318,8 +284,8 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{
continue;
}
reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])[0]
= reinterpret_cast<PackedType*>(&src_d[ii][offset_rank])[0];
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])
= *reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
}
}
}
@ -434,18 +400,22 @@ AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSiz
{
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
AllReduceParams params;
// Even plugins use ping buffers, odd plugins use pong.
// That way, we don't need to wait for the other GPUs to finish
// before copying the input tensor into the workspace.
const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;
for (int i = 0; i < tpSize; ++i)
{
params.peer_comm_buffer_ptrs[i] = buffer_ptrs[i];
params.peer_comm_buffer_ptrs[i] = buffer_ptrs[buffer_offset + i];
}
for (int i = 0; i < tpSize; ++i)
{
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[tpSize + i]);
params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[2 * tpSize + i]);
}
for (int i = 0; i < tpSize; ++i)
{
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[2 * tpSize + i]);
params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[3 * tpSize + i]);
}
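// Resulting layout of the shared buffer array (an illustration of the indexing above):
//   [0,          tpSize)      ping communication buffers
//   [tpSize,     2 * tpSize)  pong communication buffers
//   [2 * tpSize, 3 * tpSize)  input barrier flags
//   [3 * tpSize, 4 * tpSize)  output barrier flags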
params.barrier_flag = flag_value;
params.ranks_per_node = tpSize;
@ -463,19 +433,18 @@ void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts,
if (dataType == datatype_enum::TYPE_FP32)
{
using T = CustomARCommTypeConverter<float>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
kernels::invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
}
else if (dataType == datatype_enum::TYPE_FP16)
{
using T = CustomARCommTypeConverter<half>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
kernels::invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
}
#ifdef ENABLE_BF16
else if (dataType == datatype_enum::TYPE_BF16)
{
using T = CustomARCommTypeConverter<__nv_bfloat16>::Type;
kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
kernels::invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
}
#endif
else
{
TLLM_THROW("Unsupported dataType for customAllReduce");

View File

@ -30,7 +30,6 @@ namespace tensorrt_llm::kernels
{
constexpr size_t WARP_SIZE = 32;
constexpr size_t CUSTOM_AR_SIZE_THRESHOLD = 50331648;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
@ -45,16 +44,6 @@ enum class AllReduceStrategyType : int8_t
AUTO = 3,
};
#ifdef ENABLE_BF16
typedef struct bf168
{
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
} bf168;
#endif
struct AllReduceParams
{
size_t elts_total;
@ -76,26 +65,6 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream);
template <typename T>
struct CustomARCommTypeConverter
{
using Type = uint32_t;
};
template <>
struct CustomARCommTypeConverter<half>
{
using Type = uint16_t;
};
#ifdef ENABLE_BF16
template <>
struct CustomARCommTypeConverter<__nv_bfloat16>
{
using Type = __nv_bfloat16;
};
#endif
void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
common::datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream);

View File

@ -19,27 +19,253 @@ namespace tensorrt_llm
namespace kernels
{
// clang-format off
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin[];
// SingleQueryToken kernels.
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len;
// MultiQueryToken kernels.
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];
// SingleQueryToken kernels.
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
// MultiQueryToken kernels.
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
static const struct XQAKernelMetaInfo
@ -49,21 +275,137 @@ static const struct XQAKernelMetaInfo
unsigned int mHeadDim;
unsigned int mBeamWidth;
unsigned int mNumQHeadsOverKV;
unsigned int mMTileSize;
unsigned int mTokensPerPage;
bool mPagedKVCache;
bool mMultiQueryTokens;
unsigned int mSM;
const unsigned long long* mCubin;
unsigned int mCubinSize;
const char* mFuncName;
} sXqaKernelMetaInfo[] = {
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
// SingleQueryToken kernels.
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
// MultiQueryToken kernels.
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"}
};
// clang-format on
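
The array above is a lookup table: each entry pairs a set of kernel parameters (query/output data type, KV-cache data type, head size, beam width, query heads per KV head, tokens per tile, paged-KV page size, paged-KV and MultiQueryToken flags, target SM) with the matching precompiled cubin, its length, and the kernel entry-point name. The sketch below shows how a runtime could index into such a table; it is a hypothetical illustration, not the TensorRT-LLM API, and the struct, function, and field names are assumptions inferred from the cubin naming convention (dt_*, kvt_*, d_*, beam_*, nqpkv_*, m_*, pagedKV_*, sm_*).

// Minimal sketch, not the actual TensorRT-LLM implementation: XqaKernelMeta,
// selectXqaKernel, and the field names are hypothetical and only mirror the
// column order of the table above.
#include <cstddef>

enum DataType { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_INT8, DATA_TYPE_E4M3 };
enum SmVersion { kSM_80 = 80, kSM_86 = 86, kSM_89 = 89, kSM_90 = 90 };

struct XqaKernelMeta
{
    DataType ioDtype;           // dt_*     : query/output data type
    DataType kvDtype;           // kvt_*    : KV-cache data type
    int headSize;               // d_*      : head dimension (128 in every entry here)
    int beamWidth;              // beam_*   : beam width (1 here)
    int numQHeadsPerKv;         // nqpkv_*  : query heads per KV head (0 for MultiQueryToken kernels)
    int mTile;                  // m_*      : tokens processed per CTA tile
    int tokensPerPage;          // pagedKV_*: KV-cache page size in tokens (0 = contiguous KV cache)
    bool pagedKv;               // built for a paged KV cache
    bool multiQueryTokens;      // one of the "MultiQueryToken" kernels
    SmVersion sm;               // target SM architecture
    unsigned char const* cubin; // embedded cubin image
    unsigned int cubinLen;      // size of the cubin image in bytes
    char const* funcName;       // kernel entry point, e.g. "kernel_mha"
};

// Return the first entry matching the request exactly, or nullptr when no
// precompiled cubin covers that parameter combination.
XqaKernelMeta const* selectXqaKernel(XqaKernelMeta const* table, std::size_t n, DataType io, DataType kv,
    int headSize, int beamWidth, bool pagedKv, int tokensPerPage, bool multiQueryTokens, SmVersion sm)
{
    for (std::size_t i = 0; i < n; ++i)
    {
        XqaKernelMeta const& m = table[i];
        if (m.ioDtype == io && m.kvDtype == kv && m.headSize == headSize && m.beamWidth == beamWidth
            && m.pagedKv == pagedKv && m.tokensPerPage == tokensPerPage
            && m.multiQueryTokens == multiQueryTokens && m.sm == sm)
        {
            return &m;
        }
    }
    return nullptr;
}

Note, purely from the entries shown, that the single-query kernels use a fixed tile of 8 tokens (m_8) with 8 query heads per KV head, while the MultiQueryToken kernels drop the per-KV-head grouping (nqpkv_0) and come in 16- and 32-token tile variants; each combination is additionally built for contiguous KV cache and for paged KV cache with 64- or 128-token pages.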
