Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (20240116) (#891)

* Update TensorRT-LLM

---------

Co-authored-by: Eddie-Wang1120 <81598289+Eddie-Wang1120@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

parent 12e82e30b0
commit c89653021e
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file, 116 lines)
@@ -0,0 +1,116 @@
name: "Bug Report"
description: Submit a bug report to help us improve TensorRT-LLM
labels: [ "bug" ]
body:
  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: Please share your system info with us.
      placeholder: |
        - CPU architecture (e.g., x86_64, aarch64)
        - CPU/Host memory size (if known)
        - GPU properties
          - GPU name (e.g., NVIDIA H100, NVIDIA A100, NVIDIA L40S)
          - GPU memory size (if known)
          - Clock frequencies used (if applicable)
        - Libraries
          - TensorRT-LLM branch or tag (e.g., main, v0.7.1)
          - TensorRT-LLM commit (if known)
          - Versions of TensorRT, AMMO, CUDA, cuBLAS, etc. used
          - Container used (if running TensorRT-LLM in a container)
          - NVIDIA driver version
        - OS (Ubuntu 22.04, CentOS 7, Windows 10)
        - Any other information that may be useful in reproducing the bug
    validations:
      required: true

  - type: textarea
    id: who-can-help
    attributes:
      label: Who can help?
      description: |
        To expedite the response to your issue, it would be helpful if you could identify the appropriate person
        to tag using the **@** symbol. Here is a general guideline on **whom to tag**.

        Rest assured that all issues are reviewed by the core maintainers. If you are unsure about whom to tag,
        you can leave it blank, and a core maintainer will make sure to involve the appropriate person.

        Please tag fewer than 3 people.

        Quantization: @Tracin

        Documentation: @juney-nvidia

        Feature request: @ncomly-nvidia

        Performance: @kaiyux

        Others: @byshiue

      placeholder: "@Username ..."

  - type: checkboxes
    id: information-scripts-examples
    attributes:
      label: Information
      description: 'The problem arises when using:'
      options:
        - label: "The official example scripts"
        - label: "My own modified scripts"

  - type: checkboxes
    id: information-tasks
    attributes:
      label: Tasks
      description: "The tasks I am working on are:"
      options:
        - label: "An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)"
        - label: "My own task or dataset (give details below)"

  - type: textarea
    id: reproduction
    validations:
      required: true
    attributes:
      label: Reproduction
      description: |
        Kindly share a code example that demonstrates the issue you encountered. It is recommended to provide a code snippet directly.
        Additionally, if you have any error messages or stack traces related to the problem, please include them here.

        Remember to use code tags to properly format your code. You can refer to
        https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting for guidance on code formatting.

        Please refrain from using screenshots, as they can be difficult to read and prevent others from copying and pasting your code.
        It would be most helpful if we could reproduce your issue by simply copying and pasting your scripts and code.

      placeholder: |
        Steps to reproduce the behavior:

        1.
        2.
        3.

  - type: textarea
    id: expected-behavior
    validations:
      required: true
    attributes:
      label: Expected behavior
      description: "Provide a brief summary of the expected behavior of the software. Provide output files or examples if possible."

  - type: textarea
    id: actual-behavior
    validations:
      required: true
    attributes:
      label: Actual behavior
      description: "Describe the actual behavior of the software and how it deviates from the expected behavior. Provide output files or examples if possible."

  - type: textarea
    id: additional-notes
    validations:
      required: true
    attributes:
      label: Additional notes
      description: "Provide any additional context here you think might be useful for the TensorRT-LLM team to help debug this issue (such as experiments done, potential things to investigate)."
@@ -15,7 +15,8 @@ repos:
    rev: v4.1.0
    hooks:
      - id: check-added-large-files
        exclude: 'cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/'
        exclude: |
          (?x)^(.*cubin.cpp)$
      - id: check-merge-conflict
      - id: check-symlinks
      - id: detect-private-key

README.md (24 lines changed)
@@ -45,8 +45,6 @@ H200 is now 2.4x faster on Llama-70B with recent improvements to TensorRT-LLM GQ

- [TensorRT-LLM Overview](#tensorrt-llm-overview)
- [Installation](#installation)
  - [Linux](./docs/source/installation.md)
  - [Windows](windows/README.md)
- [Quick Start](#quick-start)
- [Support Matrix](#support-matrix)
  - [Devices](#devices)
@@ -110,10 +108,26 @@ concepts used in TensorRT-LLM, we recommend you to read the following

## Installation

*For Linux installation, see [`Linux`](./docs/source/installation.md).*
*For Windows installation, see [`Windows`](windows/README.md).*
After installing the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit),
please run the following commands to install TensorRT-LLM.

Once installed, commands to build and run LLMs must be executed from the TensorRT-LLM container.
```bash
# Obtain and start the basic docker image environment
nvidia-docker run --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
# Install dependencies, TensorRT-LLM requires Python 3.10
apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev
# Install the latest preview version (corresponding to the main branch) of TensorRT-LLM.
# If you want to install the stable version (corresponding to the release branch), please
# remove the `--pre` option.
pip3 install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
# Check installation
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
```

For users who require the best performance or debugging capabilities, please refer to the instructions for
[building from source code](docs/source/build_from_source.md).

For Windows installation, see [`Windows`](windows/README.md).

## Quick Start

@@ -7,7 +7,7 @@ multiple GPUs or multiple nodes with multiple GPUs.

### 1. Build TensorRT-LLM and benchmarking source code

Please follow the [`installation document`](../../docs/source/installation.md) to build TensorRT-LLM.
Please follow the [`installation document`](../../README.md#installation) to build TensorRT-LLM.

Note that the benchmarking source code for C++ runtime is not built by default; you can use the argument `--benchmarks` in [`build_wheel.py`](source:scripts/build_wheel.py) to build the corresponding executable.

@@ -123,6 +123,17 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
            bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32),
            bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32)};

        if (session.getModelConfig().computeContextLogits())
        {
            generationOutput.contextLogits
                = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
        }
        if (session.getModelConfig().computeGenerationLogits())
        {
            generationOutput.generationLogits
                = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
        }

        TLLM_LOG_INFO(memoryCounter.toString());

        for (auto r = 0; r < warmUp; ++r)
@@ -175,21 +186,20 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
                if (session.getModelConfig().computeContextLogits() && printAllLogits)
                {
                    std::cout << "generationOutput.contextLogits.shape: "
                              << generationOutput.contextLogitsHost->getShape()
                              << generationOutput.contextLogits->getShape()
                              << std::endl; // (batchsize, prompt_len, vocabsize)
                    std::cout << "generationOutput.contextLogits: " << *generationOutput.contextLogitsHost
                              << std::endl;
                    std::cout << "generationOutput.contextLogits: " << *generationOutput.contextLogits << std::endl;
                }

                if (session.getModelConfig().computeGenerationLogits() && printAllLogits)
                {
                    std::cout << "generationOutput.generationLogits.shape: "
                              << generationOutput.generationLogitsHost->getShape()
                              << generationOutput.generationLogits->getShape()
                              << std::endl; // (batchsize, beamwidth, maxNewTokens, vocabsize)
                    generationOutput.generationLogitsHost->reshape(ITensor::makeShape({batchSize * beamWidth,
                    generationOutput.generationLogits->reshape(ITensor::makeShape({batchSize * beamWidth,
                        maxNewTokens, modelConfig.getVocabSizePadded(worldConfig.getSize())}));

                    std::cout << "generationOutput.generationLogits: " << *generationOutput.generationLogitsHost
                    std::cout << "generationOutput.generationLogits: " << *generationOutput.generationLogits
                              << std::endl;
                }
            }

@@ -17,7 +17,6 @@ from argparse import ArgumentParser

# isort: off
import torch
import tensorrt as trt
# isort: on
from cuda import cuda, cudart
from mpi4py import MPI
@@ -25,8 +24,10 @@ from polygraphy.backend.trt import CreateConfig, EngineFromNetwork

import tensorrt_llm as tllm
from tensorrt_llm import Mapping, Tensor
from tensorrt_llm._ipc_utils import IpcMemory, peer_access
from tensorrt_llm._ipc_utils import peer_access
from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import (current_all_reduce_helper,
                                        init_all_reduce_helper)


def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
@@ -42,25 +43,17 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
    if world_size == 1:
        raise RuntimeError("Benchmark must run with mpi_world_size > 1")

    ipc_barriers_in = IpcMemory(
        mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
    ipc_barriers_out = IpcMemory(
        mapping, IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size)
    torch_dtype = tllm._utils.str_dtype_to_torch(dtype)

    min_size, max_size, ratio = [int(i) for i in test_range.split(",")]
    inner_loop = 1000

    size = min_size
    dtype_size = torch.finfo(torch_dtype).bits // 8
    init_all_reduce_helper()
    while size < max_size:
        ipc_buffers = IpcMemory(mapping, size * 4)
        workspace = torch.tensor(ipc_buffers.serialize() +
                                 ipc_barriers_in.serialize() +
                                 ipc_barriers_out.serialize(),
                                 dtype=torch.int64,
                                 device="cpu")

        input = torch.zeros(size, dtype=torch_dtype, device="cuda")
        _buffers, workspace = current_all_reduce_helper().allocate_workspace(
            mapping, size * dtype_size)
        input = torch.ones(size, dtype=torch_dtype, device="cuda")

        for strategy in [
                AllReduceStrategy.RING, AllReduceStrategy.ONESHOT,
@@ -77,16 +70,11 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
                           shape=input.shape,
                           dtype=tllm.str_dtype_to_trt(dtype))

                w = Tensor(name='workspace',
                           shape=workspace.shape,
                           dtype=trt.int64)
                current_all_reduce_helper().set_workspace_tensor(mapping)

                current = x
                for i in range(inner_loop):
                    current = allreduce(
                        current, mapping.tp_group,
                        w if strategy != AllReduceStrategy.RING else None, i,
                        strategy)
                for _ in range(inner_loop):
                    current = allreduce(current, mapping.tp_group, strategy)
                output = current.trt_tensor

                output.name = 'output'
@@ -104,7 +92,7 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
            output = torch.zeros_like(input)

            stream = torch.cuda.current_stream()
            feed_dict = {'x': input, 'workspace': workspace}
            feed_dict = {'x': input, 'all_reduce_workspace': workspace}

            session = tllm.runtime.Session.from_engine(build_engine())
            _, start = cuda.cuEventCreate(0)
@@ -119,9 +107,11 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
            cuda.cuEventRecord(stop, stream.cuda_stream)
            torch.cuda.synchronize()
            _, ms = cuda.cuEventElapsedTime(start, stop)
            assert torch.allclose(output, (input * world_size)**inner_loop)

            if mapping.rank == 0:
                print(f"{size=}, {strategy=}, {ms=}")

        size *= ratio
    if mapping.rank == 0:
        print("")

@@ -197,8 +197,8 @@ _allowed_configs = {
            builder_opt=None,
            quantization="int8_sq_per_token_channel",
        )),
    "gpt-next_2b":
    ModelConfig(name="gpt-next_2b",
    "gpt_next_2b":
    ModelConfig(name="gpt_next_2b",
                family="gpt",
                benchmark_type="gpt",
                build_config=BuildConfig(
@@ -305,25 +305,6 @@ _allowed_configs = {
                    max_output_len=200,
                    builder_opt=None,
                )),
    "llama_7b_moe":
    ModelConfig(name="llama_7b_moe",
                family="llama",
                benchmark_type="gpt",
                build_config=BuildConfig(
                    num_layers=32,
                    num_heads=32,
                    hidden_size=4096,
                    vocab_size=32000,
                    hidden_act='silu',
                    n_positions=2048,
                    inter_size=11008,
                    max_batch_size=128,
                    max_input_len=512,
                    max_output_len=200,
                    builder_opt=None,
                    moe_num_experts=4,
                    moe_top_k=1,
                )),
    "llama_13b":
    ModelConfig(name="llama_13b",
                family="llama",
@@ -427,6 +408,25 @@ _allowed_configs = {
                    max_output_len=200,
                    builder_opt=None,
                    quantization="int8_sq_per_tensor")),
    "mixtral_8x7b":
    ModelConfig(name="mixtral_8x7b",
                family="llama",
                benchmark_type="gpt",
                build_config=BuildConfig(
                    num_layers=32,
                    num_heads=32,
                    hidden_size=4096,
                    vocab_size=32000,
                    hidden_act='swiglu',
                    n_positions=2048,
                    inter_size=14336,
                    max_batch_size=128,
                    max_input_len=512,
                    max_output_len=200,
                    builder_opt=None,
                    moe_num_experts=8,
                    moe_top_k=2,
                )),
    "gptj_6b":
    ModelConfig(name="gptj_6b",
                family="gptj",

@@ -251,7 +251,6 @@ def main(args):
    from bert_benchmark import BERTBenchmark
    from enc_dec_benchmark import EncDecBenchmark
    from gpt_benchmark import GPTBenchmark
    from mem_monitor import MemoryMonitor

    import tensorrt_llm
    from tensorrt_llm.logger import logger
@@ -281,6 +280,13 @@ def main(args):
    rank = tensorrt_llm.mpi_rank()
    world_size = tensorrt_llm.mpi_world_size()

    # TODO: Re-enable memory monitor for multi-gpu benchmarks.
    # The current MemoryMonitor causes the benchmark script to hang
    # because MPI does not work well with multiprocessing.
    disable_mem_monitor = world_size > 1
    if not disable_mem_monitor:
        from mem_monitor import MemoryMonitor

    benchmark_profiler = None
    if args.model in get_allowed_models(benchmark_type="gpt"):
        benchmark_profiler = BenchmarkProfiler()
@@ -314,8 +320,9 @@ def main(args):
            torch.cuda.empty_cache()
            latencies = []

            memory_monitor = MemoryMonitor()
            memory_monitor.start()
            if not disable_mem_monitor:
                memory_monitor = MemoryMonitor()
                memory_monitor.start()

            iter_idx = 0
            try:
@@ -346,12 +353,17 @@ def main(args):

            except Exception as e:
                print("Found exception during benchmarking", e.with_traceback())
                memory_monitor.kill()
                if not disable_mem_monitor:
                    memory_monitor.kill()
                raise e

            memory_monitor.stop()
            _, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
            peak_gpu_used = round(peak_gpu_used, 3)
            if not disable_mem_monitor:
                memory_monitor.stop()
                _, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
                peak_gpu_used = round(peak_gpu_used, 3)
            else:
                peak_gpu_used = 0.0

            if benchmark_profiler is not None:
                benchmark_profiler.add_aux_info('iter_count', iter_idx)
                benchmark_profiler.stop()

@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from multiprocessing import Event, Process, Queue

from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
                                   device_memory_info)
                                   device_memory_info, host_memory_info)


class MemoryMonitor:
@@ -28,6 +29,7 @@ class MemoryMonitor:
        self._peak_host_memory = 0
        self._peak_device_memory = 0

        self.pid = os.getpid()
        self.device_handles = {}

        self.signal_event = Event()  # Sending signal to subprocess
@@ -49,19 +51,28 @@ class MemoryMonitor:
        self.signal_event.set()
        logger.debug("Sent signal to stop memory monitor subprocess.")

        peak_mem_use = self.peak_mem_queue.get()

        self._peak_host_memory = max(self._peak_host_memory, peak_mem_use[0])
        self._peak_device_memory = max(self._peak_device_memory,
                                       self.peak_mem_queue.get())
                                       peak_mem_use[1])

        self.mem_monitor_process.join()
        self.mem_monitor_process = None
        logger.debug("Memory monitor subprocess joined.")

    def _upd_peak_memory_usage(self, signal_event, peak_mem_queue):
        peak_used, _, _ = device_memory_info()
        peak_host_used, peak_device_used = self.get_memory_usage()
        while not signal_event.is_set():
            used, _, _ = device_memory_info()
            peak_used = max(used, peak_used)
        peak_mem_queue.put(peak_used)
            host_used, device_used = self.get_memory_usage()
            peak_host_used = max(host_used, peak_host_used)
            peak_device_used = max(device_used, peak_device_used)
        peak_mem_queue.put((peak_host_used, peak_device_used))

    def get_memory_usage(self):
        host_used, _, _ = host_memory_info(self.pid)
        device_used, _, _ = device_memory_info()
        return host_used, device_used

    def get_peak_memory_usage(self, unit: MemUnitType = 'GiB'):
        return bytes_to_target_unit(self._peak_host_memory, unit), \

@@ -29,7 +29,7 @@ project(tensorrt_llm LANGUAGES CXX)
# Build options
option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
       OFF)
       ON)
option(BUILD_TESTS "Build Google tests" ON)
option(BUILD_BENCHMARKS "Build benchmarks" ON)
option(NVTX_DISABLE "Disable all NVTX features" ON)

@@ -33,7 +33,7 @@ public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
        std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = true,
        std::optional<SizeType> maxNumSequences = std::nullopt, bool enableTrtOverlap = false,
        std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt, bool normalizeLogProbs = true,
        bool logIterationData = false)
        : kvCacheConfig{kvCacheConfig}

@@ -39,8 +39,6 @@ public:
    {
        TLLM_CHECK_WITH_INFO(static_cast<bool>(this->ids), "Invalid ids tensor");
        TLLM_CHECK_WITH_INFO(static_cast<bool>(this->lengths), "Invalid lengths tensor");

        generationLogitsFragments = std::make_shared<std::vector<TensorPtr>>();
    }

    // mandatory parameters
@@ -53,11 +51,6 @@ public:
    TensorPtr contextLogits;    // [batch_size, max_input_length, vocab_size_padded], if packed, the shape will be
                                // [packed_size, vocab_size_padded]
    TensorPtr generationLogits; // [batch_size, beam_width, max_output_length, vocab_size_padded]
    // generation logit pointer list
    std::shared_ptr<std::vector<TensorPtr>> generationLogitsFragments;

    TensorPtr contextLogitsHost;
    TensorPtr generationLogitsHost;

    // callbacks
    Callback onTokenGenerated;

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32eb4ace3a218b307f4a8ade8b2f47540d690118539c75bbff82ee251e098f3c
size 1927834
oid sha256:85691c252e1025d087c3aeb34fa249ea2424b079b883c28228ec0f838704ee0b
size 1920092

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dcf82d3f89bfd45abffbdd0856bd18d0c80154397fedd4617f65efca8c5619d5
size 1938574
oid sha256:5ac6ca64007c016dff30aadb3f167b263c31aef995a66ac96d731d2c077a3535
size 1930918

@@ -1,3 +1,3 @@
0ad3a2c87135a42a7c766472e3690945 libtensorrt_llm_batch_manager_static.a
e83f9560653b8a7bf0e5465fed610422 libtensorrt_llm_batch_manager_static.pre_cxx11.a
87230159783a783af845d5fbc908d47108fbb754 commit
435c2a649cca52168e243ff67b7e4811 libtensorrt_llm_batch_manager_static.a
a6c17eee9cb72f3442482656a0f7307c libtensorrt_llm_batch_manager_static.pre_cxx11.a
d086d9953961e0cf567a4f6185d929df4c5bf9ec commit

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00627c7c18991cf9d832c63d4df4174744b4bed2d3a393051db5523f36a6a16b
size 1867794
oid sha256:52183471dbd6db61eff4cc09f78f56e8ce48f624947aa81806f7fb7585dd2617
size 1863048

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d06f1cc77ea915b9494f6362a8229dbc6d4728711b25c7dab813869a75b599e
size 1847520
oid sha256:a63a46d5744b119398678983e252b2a142d2005894643650eb698cb98e5362cd
size 1838410

@@ -1,2 +1,2 @@
1424807fe5927445d2daf3a3d9db19cc libtensorrt_llm_batch_manager_static.a
afc1bba58d7138cfb018790b62a92990 libtensorrt_llm_batch_manager_static.pre_cxx11.a
591e69f5ed2dd632ef62dc57900894db libtensorrt_llm_batch_manager_static.a
1733f8c663e9c0c1969e74f6a741dbb8 libtensorrt_llm_batch_manager_static.pre_cxx11.a

@@ -200,15 +200,6 @@ inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hmul2(x, y);
};

inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return bf16hadd2(x, y);
};

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
@@ -301,3 +292,22 @@ inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, _

} // namespace common
} // namespace tensorrt_llm

// Operator definitions intentionally in global namespace
namespace
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020)

inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return tensorrt_llm::common::bf16hmul2(x, y);
};

inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
    return tensorrt_llm::common::bf16hadd2(x, y);
};
#endif
#endif
} // namespace

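For context, these relocated operators let builds targeting pre-SM80 architectures, or CUDA runtimes older than 12.2 that lack the built-in overloads, use natural arithmetic syntax on packed `__nv_bfloat162` values. A minimal sketch of a device function relying on them; `scaleAdd` is a hypothetical name, not part of this commit:

```cpp
#include <cuda_bf16.h>

// Hypothetical device helper: with operator* and operator+ visible (either the
// CUDA built-ins on newer targets, or the fallbacks guarded above), packed
// bfloat16 math reads like ordinary arithmetic.
__device__ __nv_bfloat162 scaleAdd(__nv_bfloat162 x, __nv_bfloat162 scale, __nv_bfloat162 bias)
{
    return x * scale + bias; // resolves to bf16hmul2/bf16hadd2 under the guards above
}
```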
cpp/tensorrt_llm/common/customAllReduceUtils.h (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
{
|
||||
|
||||
constexpr size_t NUM_POINTERS_PER_RANK = 4;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
|
||||
size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
{
|
||||
if (worldSize <= 2)
|
||||
{
|
||||
return 16 * 1000 * 1000;
|
||||
}
|
||||
return 8 * 1000 * 1000;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
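As a rough illustration only (not part of the commit), a caller sizing a custom all-reduce workspace might combine the two symbols above like this; the helper name and the pointer-slot arithmetic are assumptions for the sketch:

```cpp
#include <cstddef>
// Assumed in-repo include path for the header above.
#include "tensorrt_llm/common/customAllReduceUtils.h"

namespace carUtils = tensorrt_llm::utils::customAllReduceUtils;

// Hypothetical sizing helper: one message buffer sized by world size, plus the
// fixed per-rank pointer slots tracked by NUM_POINTERS_PER_RANK.
size_t estimateWorkspaceBytes(int tpSize)
{
    size_t const perRank = carUtils::getMaxRequiredWorkspaceSize(tpSize);
    return perRank + carUtils::NUM_POINTERS_PER_RANK * tpSize * sizeof(void*);
}
```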
@@ -22,6 +22,21 @@
namespace tensorrt_llm::common
{

// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
    const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
    static bool forceXQA = false;
    if (force_xqa_env_var != nullptr)
    {
        if (force_xqa_env_var[0] == '1' && force_xqa_env_var[1] == '\0')
        {
            forceXQA = true;
        }
    }
    return forceXQA;
}

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{

@@ -20,6 +20,9 @@
namespace tensorrt_llm::common
{

// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();

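For illustration, a dispatch site could consult the new helper as below; the surrounding kernel-selection logic is a hypothetical sketch, not taken from the commit. The only documented contract is that `TRTLLM_FORCE_XQA=1` in the environment makes `forceXQAKernels()` return true.

```cpp
// Assumed in-repo include path for the declarations above.
#include "tensorrt_llm/common/envUtils.h"

void selectGenerationKernel(bool xqaSupported)
{
    if (xqaSupported && tensorrt_llm::common::forceXQAKernels())
    {
        // ... dispatch to the XQA generation-phase kernel (hypothetical) ...
    }
    else
    {
        // ... fall back to the default attention path (hypothetical) ...
    }
}
```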
@@ -860,5 +860,47 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_
template bool invokeCheckRange<int>(
    const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);

/*
 * Determine the total workspace size based on a vector containing multiple variable sizes.
 */
size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTES)
{
    const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
    // Check ALIGN_BYTES is a power of 2
    assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);

    size_t total = 0;
    for (auto sz : sizes)
    {
        total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
    }

    // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is
    // not aligned.
    return total + ALIGN_BYTES - 1;
}

/*
 * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses
 * of each variable.
 */
void calcAlignedPointers(
    std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES)
{
    const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
    // Check ALIGN_BYTES is a power of 2
    assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0);

    // In case the start address is not aligned
    char* ptr = reinterpret_cast<char*>((reinterpret_cast<size_t>(p) + ALIGN_BYTES - 1) & ALIGN_MASK);

    outPtrs.reserve(sizes.size());
    for (auto sz : sizes)
    {
        outPtrs.push_back(ptr);
        ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK;
    }
}

} // namespace common
} // namespace tensorrt_llm

@@ -258,5 +258,8 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T>
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);

size_t calcAlignedSize(const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
void calcAlignedPointers(
    std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
} // namespace common
} // namespace tensorrt_llm

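To make the intended pairing of these two helpers concrete, here is a small usage sketch (assumed, not from the commit): size one device allocation with `calcAlignedSize`, then carve per-variable pointers out of it with `calcAlignedPointers`.

```cpp
#include <cuda_runtime.h>
#include <vector>

void partitionWorkspace()
{
    using namespace tensorrt_llm::common;

    // Three variables of 100, 257 and 64 bytes; each is padded up to the
    // 256-byte default alignment (256 + 512 + 256), plus 255 slack bytes
    // in case the base address returned by the allocator is unaligned.
    std::vector<size_t> sizes{100, 257, 64};
    size_t total = calcAlignedSize(sizes); // 1024 + 255 = 1279 bytes

    void* workspace = nullptr;
    cudaMalloc(&workspace, total);

    std::vector<void*> ptrs;
    calcAlignedPointers(ptrs, workspace, sizes); // ptrs[0..2] are 256-byte aligned

    cudaFree(workspace);
}
```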
@@ -0,0 +1,542 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/

#pragma once

#include <limits>
#include <numeric>
#include <vector>

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk,
    int64_t* splitk_buffer_offsets)
{
    // in_tensor: [problem_idx, k_partition, hidden_size]
    // Note that different requests of in_tensor might have different hidden_size (=m*n)
    // so, we need to use splitk_buffer_offsets.
    // out_tensor: problem_idx * [hidden_size]

    const int problem_idx = blockIdx.y;
    GemmCoord problem = problem_sizes[problem_idx];
    const int hidden_size = problem.m() * problem.n();
    const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
    T_OUT* out_tensor_ = out_tensor[problem_idx];

    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
    {
        float sum = 0.0f;
        for (int k_idx = 0; k_idx < splitk; k_idx++)
        {
            sum += (float) in_tensor_[k_idx * hidden_size + i];
        }
        out_tensor_[i] = (T_OUT) (sum);
    }
}

/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
    using BaseKernel = BaseKernel_;

    using ElementA = typename BaseKernel::ElementA;
    using LayoutA = typename BaseKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = BaseKernel::kTransformA;
    static int const kAlignmentA = BaseKernel::kAlignmentA;

    using ElementB = typename BaseKernel::ElementB;
    using LayoutB = typename BaseKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = BaseKernel::kTransformB;
    static int const kAlignmentB = BaseKernel::kAlignmentB;

    using ElementC = typename BaseKernel::ElementC;
    using LayoutC = typename BaseKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;
    static int const kAlignmentC = BaseKernel::kAlignmentC;

    using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;

    using Operator = typename BaseKernel::Operator;
    using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;

    using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
    using MathOperator = typename WarpMmaOperator::MathOperator;
    using OperatorClass = typename WarpMmaOperator::OperatorClass;
    using ArchTag = typename WarpMmaOperator::ArchTag;
    using ThreadblockShape = typename BaseKernel::Mma::Shape;
    using WarpShape = typename BaseKernel::WarpShape;
    using InstructionShape = typename BaseKernel::InstructionShape;
    static int const kStages = BaseKernel::Mma::kStages;

    /// Argument structure
    using Arguments = typename BaseKernel::Arguments;

    using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;

protected:
    /// Kernel parameters object
    typename BaseKernel::Params gemm_params_;

private:
    /// Get the number of tiles across all problems in a group
    static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count)
    {
        int32_t tiles = 0;
        for (int32_t i = 0; i < problem_count; ++i)
        {
            cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
            BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
            tiles += problem_tile_count(problem);
        }
        return tiles;
    }

    /// Copy from `data` to `workspace`
    Status copy_to_workspace(void* workspace, void* data, size_t bytes)
    {
        cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
        if (cuda_error != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            cuda_error = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Precomputes scheduling information for the grouped GEMM
    Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
    {
        size_t workspace_bytes = get_workspace_size(args);
        std::vector<uint8_t> host_workspace(workspace_bytes);
        BaseKernel::ProblemVisitor::host_precompute(
            args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
        return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
    }

    /// Reorder `data` according to `indices`
    template <typename T>
    static void reorder_array(T* data, const std::vector<size_t>& indices)
    {
        // For now, simply create a copy of the data and then copy over to the original.
        std::vector<T> copy(indices.size());
        for (size_t i = 0; i < indices.size(); ++i)
        {
            copy.at(i) = data[indices[i]];
        }

        memcpy(data, copy.data(), indices.size() * sizeof(T));
    }

public:
    /// Constructs the GEMM.
    BaseSplitkGrouped() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {

        return BaseKernel::can_implement(args);
    }

    /// Get the number of tiles in a problem
    static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
    {
        auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
        return BaseKernel::ProblemVisitor::tile_count(grid);
    }

    /// Get the number of tiles across all problems in a group
    static int32_t group_tile_count(Arguments const& args)
    {
        if (args.host_problem_sizes == nullptr)
        {
            CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
            return -1;
        }

        return group_tile_count(args.host_problem_sizes, args.problem_count);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        size_t total_mn = 0;
        for (int i = 0; i < args.problem_count; i++)
        {
            total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
        }
        size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
                args.host_problem_sizes, args.problem_count, args.threadblock_count);
        }
        return workSpaceSize;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {

        return dim3(args.threadblock_count, 1, 1);
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {

        CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");

        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

        cudaError_t result;
        if (smem_size > (48 << 10))
        {
            result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                // Call cudaGetLastError() to clear the error bit
                result = cudaGetLastError();
                CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
                return -1;
            }
        }

        int max_active_blocks = -1;
        result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);

        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(
                " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
            return -1;
        }

        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
    }

    /// Sorts each pointer passed in according to the indices that sort
    /// `problem_sizes_ptr` in descending order of problem-K dimension.
    static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
        int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
        int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
    {
        std::vector<size_t> indices(problem_count);
        std::iota(indices.begin(), indices.end(), 0);
        std::stable_sort(indices.begin(), indices.end(),
            [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });

        reorder_array(problem_sizes_ptr, indices);
        reorder_array(lda_host_ptr, indices);
        reorder_array(ldb_host_ptr, indices);
        reorder_array(ldc_host_ptr, indices);
        reorder_array(ldd_host_ptr, indices);
        reorder_array(offset_A_ptr, indices);
        reorder_array(offset_B_ptr, indices);
        reorder_array(offset_C_ptr, indices);
        reorder_array(offset_D_ptr, indices);
    }

    /// Computes the number of threadblocks to launch for the grouped kernel
    static int sufficient(
        const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
    {
        // Determine the number of blocks that would be launched to fill up a single
        // wave on the GPU with each SM having maximum occupancy.
        int device_idx;
        cudaError_t result = cudaGetDevice(&device_idx);
        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
            return 0;
        }

        int multiprocessor_count;
        result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
            return 0;
        }

        bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
        if (override_sm_count)
        {
            available_sm_count = multiprocessor_count;
        }

        int max_active_blocks = maximum_active_blocks();
        if (max_active_blocks <= 0)
        {
            return 0;
        }

        int occupancy_based_block_count = available_sm_count * max_active_blocks;

        if (problem_sizes_ptr == nullptr || problem_count == 0)
        {
            return occupancy_based_block_count;
        }

        int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);

        // If the group contains a single problem, launching the exact number of
        // threadblocks needed to cover the problem minimizes the work performed
        // per threadblock in finding the next tile to compute. We return total_tiles
        // unless the user has provided the SM count.
        if (problem_count == 1 && override_sm_count)
        {
            return total_tiles;
        }

        // Choose between the full wave of threadblocks and the tile count. If there
        // are fewer tiles in the group than threadblocks in the full wave, only
        // some threadblocks will be assigned tiles. Those threadblocks
        // which are not assigned tiles still need to perform the work of iterating through
        // problem sizes to determine that they have no work to do. This competes for cycles
        // with those threadblocks that are assigned tiles to compute.
        return std::min(total_tiles, occupancy_based_block_count);
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {

        CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        // Workspace
        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
        }
        else
        {
            gemm_params_ = typename BaseKernel::Params(args, workspace);
        }

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_.update(args, workspace, tile_count);
        }
        else
        {
            gemm_params_.update(args, workspace);
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        if (!gemm_params_.problem_visitor.problem_count)
        {
            return Status::kSuccess;
        }

        //
        // Launch kernel
        //

        // Launch splitk grouped gemm
        {
            dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
            dim3 block(BaseKernel::kThreadCount, 1, 1);

            int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
            cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        // Launch splitkReduction
        {
            dim3 grid(32, gemm_params_.problem_visitor.problem_count);
            dim3 block(256);
            splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
                gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
                gemm_params_.splitk_buffer_offsets);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Initializes and runs the kernel.
    Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
    {

        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
    using GemmKernel = GemmKernel_;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
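One detail of the file above worth spelling out: `splitkReduction` indexes its input through `splitk_buffer_offsets` because each grouped problem contributes `m*n` accumulator elements per split-k slice, so the offsets are a prefix sum over problem output sizes. A hedged host-side sketch of how such offsets could be built; the helper name is assumed, not taken from the file:

```cpp
#include <cstdint>
#include <vector>
#include "cutlass/gemm_coord.h"

// Hypothetical helper: offsets[i] is the number of output elements of all
// problems before i, so in_tensor + offsets[i] * splitk is where problem i's
// split-k partial sums begin (matching splitkReduction's indexing).
std::vector<int64_t> makeSplitkBufferOffsets(const std::vector<cutlass::gemm::GemmCoord>& problems)
{
    std::vector<int64_t> offsets(problems.size());
    int64_t running = 0;
    for (size_t i = 0; i < problems.size(); ++i)
    {
        offsets[i] = running;
        running += static_cast<int64_t>(problems[i].m()) * problems[i].n();
    }
    return offsets;
}
```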
@ -0,0 +1,207 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief
|
||||
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
|
||||
the appropriate threadblock-scoped epilogue.
|
||||
|
||||
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
|
||||
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
|
||||
specializations here choose 'device::GemmTransposed' to implement this functionality.
|
||||
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_complex.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
|
||||
#include "cutlass/layout/permute.h"
|
||||
|
||||
#include "splitk_gemm_grouped.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace kernel
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Complex elementwise transformation on A operand
|
||||
ComplexTransform TransformA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Complex elementwise transformation on B operand
|
||||
ComplexTransform TransformB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator = typename device::DefaultGemmConfiguration<OperatorClass, ArchTag, ElementA_, ElementB_,
|
||||
ElementC_, ElementAccumulator>::Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout = layout::NoPermute,
|
||||
///
|
||||
typename Enable = void>
|
||||
struct DefaultSplitkGemmGrouped;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Real-valued GEMM kernels
|
||||
//
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for C and D matrix operands
|
||||
typename ElementC,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Epilogue output operator
|
||||
typename EpilogueOutputOp,
|
||||
/// Threadblock-level swizzling operator
|
||||
typename ThreadblockSwizzle,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Whether the schedule of problems to visit has been precomputed
|
||||
GroupScheduleMode GroupScheduleMode_,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
/// Permute result D
|
||||
typename PermuteDLayout>
|
||||
struct DefaultSplitkGemmGrouped<ElementA, LayoutA,
|
||||
ComplexTransform::kNone, // transform A
|
||||
kAlignmentA, ElementB, LayoutB,
|
||||
ComplexTransform::kNone, // transform B
|
||||
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
|
||||
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear,
|
||||
PermuteDLayout, typename platform::enable_if<!cutlass::is_complex<ElementAccumulator>::value>::type>
|
||||
{
|
||||
|
||||
// If true, we must construct a 'transposed-and-exchanged' Mma operator.
|
||||
static bool const kInternalTranspose = platform::is_same<LayoutC, layout::ColumnMajor>::value;
|
||||
|
||||
using MapArguments = kernel::detail::MapArguments<ElementA, LayoutA, ComplexTransform::kNone, kAlignmentA, ElementB,
|
||||
LayoutB, ComplexTransform::kNone, kAlignmentB, LayoutC, kInternalTranspose>;
|
||||
|
||||
// Define the default GEMM kernel
|
||||
using DefaultGemmKernel = typename kernel::DefaultGemm<typename MapArguments::ElementA,
|
||||
typename MapArguments::LayoutA, MapArguments::kAlignmentA, typename MapArguments::ElementB,
|
||||
typename MapArguments::LayoutB, MapArguments::kAlignmentB, ElementC, typename MapArguments::LayoutC,
|
||||
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
|
||||
ThreadblockSwizzle, Stages, true, Operator, SharedMemoryClear, false, /*GatherA*/
|
||||
false, /*GatherB*/
|
||||
false, /*ScatterD*/
|
||||
PermuteDLayout>::GemmKernel;
|
||||
|
||||
/// Define the kernel in terms of the default kernel
|
||||
using GemmKernel = kernel::SplitkGemmGrouped<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue,
|
||||
ThreadblockSwizzle, GroupScheduleMode_, kInternalTranspose>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
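For orientation, here is a minimal sketch (not part of this commit) of how the real-valued specialization above might be instantiated. The fp16 element types, tile shapes, swizzle, and alignment of 8 are illustrative assumptions, not values taken from the diff:

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// Hypothetical instantiation: fp16 A/B/C with fp32 accumulation on SM80.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>;

using SplitkGroupedKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,    // A
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8, // B
    cutlass::half_t, cutlass::layout::RowMajor,                                         // C and D
    float, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>, // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,    // tensor core instruction tile
    EpilogueOp, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
    4>::GemmKernel;                         // 4-stage mainloop; remaining parameters defaulted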
@ -0,0 +1,494 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/

#pragma once

#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,                   ///! Epilogue
    typename ThreadblockSwizzle_,         ///! Threadblock swizzling function
    GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
    bool Transposed = false>
struct SplitkGemmGrouped
{
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
    static bool const kTransposed = Transposed;

    // Optional transpose
    using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
        Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
        typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
        kTransposed>;

    // Public-facing type definitions related to operand element type, layout, and complex conjugate
    // operation. Must interact with the 'kTransposed' notion.
    using ElementA = typename MapArguments::ElementA;
    using LayoutA = typename MapArguments::LayoutA;
    using ElementB = typename MapArguments::ElementB;
    using LayoutB = typename MapArguments::LayoutB;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename MapArguments::LayoutC;

    using ElementFinalOutput = typename MapArguments::ElementA;

    static ComplexTransform const kTransformA = MapArguments::kTransformA;
    static ComplexTransform const kTransformB = MapArguments::kTransformB;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = MapArguments::kAlignmentA;
    static int const kAlignmentB = MapArguments::kAlignmentB;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    using ProblemVisitor
        = GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;

    //
    // Structures
    //

    /// Argument structure
    struct Arguments
    {

        //
        // Data members
        //

        GemmCoord* problem_sizes;
        int problem_count;
        int threadblock_count;

        typename EpilogueOutputOp::Params output_op;

        ElementA** ptr_A;
        ElementB** ptr_B;
        ElementFinalOutput** ptr_C;
        ElementFinalOutput** ptr_D;

        typename LayoutA::Stride::LongIndex* lda;
        typename LayoutB::Stride::LongIndex* ldb;
        typename LayoutC::Stride::LongIndex* ldc;
        typename LayoutC::Stride::LongIndex* ldd;

        // Only used by device-level operator
        GemmCoord* host_problem_sizes;

        // splitK
        int split_k_slices;
        int64_t* splitk_buffer_offsets;

        //
        // Methods
        //

        /// Default ctor
        CUTLASS_HOST_DEVICE
        Arguments()
            : problem_count(0)
            , threadblock_count(0)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , lda(nullptr)
            , ldb(nullptr)
            , ldc(nullptr)
            , ldd(nullptr)
            , host_problem_sizes(nullptr)
            , split_k_slices(1)
            , splitk_buffer_offsets(nullptr)
        {
        }

        /// Ctor
        CUTLASS_HOST_DEVICE
        Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
            typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
            ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
            typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
            typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
            int64_t* splitk_buffer_offsets)
            : problem_sizes(problem_sizes)
            , problem_count(problem_count)
            , threadblock_count(threadblock_count)
            , output_op(output_op)
            , ptr_A(ptr_A)
            , ptr_B(ptr_B)
            , ptr_C(ptr_C)
            , ptr_D(ptr_D)
            , lda(lda)
            , ldb(ldb)
            , ldc(ldc)
            , ldd(ldd)
            , host_problem_sizes(host_problem_sizes)
            , split_k_slices(split_k_slices)
            , splitk_buffer_offsets(splitk_buffer_offsets)
        {
        }
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params
    {

        typename ProblemVisitor::Params problem_visitor;
        int threadblock_count;

        typename EpilogueOutputOp::Params output_op;

        ElementA** ptr_A;
        ElementB** ptr_B;
        ElementFinalOutput** ptr_C;
        ElementFinalOutput** ptr_D;
        ElementC* ptr_C_split;
        ElementC* ptr_D_split;

        typename LayoutA::Stride::LongIndex* lda;
        typename LayoutB::Stride::LongIndex* ldb;
        typename LayoutC::Stride::LongIndex* ldc;
        typename LayoutC::Stride::LongIndex* ldd;

        //
        // Methods
        //

        // splitk
        GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        int gemm_k_size;
        GemmCoord* host_problem_sizes;
        int split_k_slices;
        int64_t* splitk_buffer_offsets;

        CUTLASS_HOST_DEVICE
        Params()
            : ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , ptr_C_split(nullptr)
            , ptr_D_split(nullptr)
            , lda(nullptr)
            , ldb(nullptr)
            , ldc(nullptr)
            , ldd(nullptr)
            , swizzle_log_tile(0)
            , gemm_k_size(0)
            , host_problem_sizes(nullptr)
            , split_k_slices(1)
            , splitk_buffer_offsets(nullptr)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
            : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
            , host_problem_sizes(args.host_problem_sizes)
            , threadblock_count(args.threadblock_count)
            , output_op(args.output_op)
            , ptr_A(args.ptr_A)
            , ptr_B(args.ptr_B)
            , ptr_C(args.ptr_C)
            , ptr_D(args.ptr_D)
            , ptr_C_split((ElementC*) workspace)
            , ptr_D_split((ElementC*) workspace)
            , lda(args.lda)
            , ldb(args.ldb)
            , ldc(args.ldc)
            , ldd(args.ldd)
            , split_k_slices(args.split_k_slices)
            , splitk_buffer_offsets(args.splitk_buffer_offsets)
        {
            // Determine grid shape
            ThreadblockSwizzle threadblock_swizzle;
            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
            swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);

            // only support same k
            int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
            int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();

            gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
        }

        CUTLASS_HOST_DEVICE
        void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
        {

            problem_visitor =
                typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
            threadblock_count = args.threadblock_count;
            output_op = args.output_op;
            ptr_A = args.ptr_A;
            ptr_B = args.ptr_B;
            ptr_C = args.ptr_C;
            ptr_D = args.ptr_D;
            ptr_C_split = (ElementC*) workspace;
            ptr_D_split = (ElementC*) workspace;

            lda = args.lda;
            ldb = args.ldb;
            ldc = args.ldc;
            ldd = args.ldd;
        }
    };

    /// Shared memory storage structure
    struct SharedStorage
    {
        union
        {
            typename Mma::SharedStorage main_loop;
            typename Epilogue::SharedStorage epilogue;
        } kernel;

        // ProblemVisitor shared storage can't be overlapped with others
        typename ProblemVisitor::SharedStorage problem_visitor;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    SplitkGemmGrouped() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        return Status::kSuccess;
    }

    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {

        //
        // These types shadow the type-level definitions and support the ability to implement
        // a 'transposed' GEMM that computes the transposed problems.
        //
        using ElementA = typename Mma::IteratorA::Element;
        using LayoutA = typename Mma::IteratorA::Layout;
        using ElementB = typename Mma::IteratorB::Element;
        using LayoutB = typename Mma::IteratorB::Layout;
        using ElementC = typename Epilogue::OutputTileIterator::Element;
        using LayoutC = typename Epilogue::OutputTileIterator::Layout;

        //
        // Problem visitor.
        //
        ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

        // Outer 'persistent' loop to iterate over tiles
        while (problem_visitor.next_tile())
        {

            GemmCoord problem_size = problem_visitor.problem_size();
            int32_t problem_idx = problem_visitor.problem_index();
            int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

            GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

            // Load element pointers. Exchange pointers and strides if working on the transpose
            ElementA* ptr_A
                = reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
            typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

            ElementB* ptr_B
                = reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
            typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;
            GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
                int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_offset.m(),
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k;
            if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
            {
                problem_size_k = problem_size.k();
            }
            else
            {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(
                LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

            typename Mma::IteratorB iterator_B(
                LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = canonical_warp_idx_sync();

            int lane_idx = threadIdx.x % 32;

            //
            // Matrix multiply phase
            //

            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

            // Wait for all threads to finish their epilogue phases from the previous tile.
            __syncthreads();

            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            ElementC* ptr_C = params.ptr_C_split;
            ElementC* ptr_D = params.ptr_D_split;

            LayoutC layout_C(params.ldc[problem_idx]);
            LayoutC layout_D(params.ldd[problem_idx]);

            typename Epilogue::OutputTileIterator::Params params_C(layout_C);
            typename Epilogue::OutputTileIterator::Params params_D(layout_D);

            // assume identity swizzle
            MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(
                params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);

            iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
                + gridDim.z * params.splitk_buffer_offsets[problem_idx]);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(
                params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
            iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
                + gridDim.z * params.splitk_buffer_offsets[problem_idx]);

            Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);

            // Execute the epilogue operator to update the destination tensor.
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            // Next tile
            problem_visitor.advance(gridDim.x);
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
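To make the split-K bookkeeping in the Params constructor above concrete, here is a small self-contained worked example (the sizes are made-up illustrations, not values from the commit): with K = 4096, Mma::Shape::kK = 32, and split_k_slices = 4, each K-slice covers gemm_k_size = 1024 elements, and the last slice also absorbs any remainder, exactly as the problem_size_k branch in operator() does.

#include <cstdio>

int main()
{
    int const K = 4096, kK = 32, split_k_slices = 4;
    int full_gemm_k_iterations = K / kK;                             // 128 mainloop iterations in total
    int gemm_k_iterations = full_gemm_k_iterations / split_k_slices; // 32 iterations per slice
    int gemm_k_size = gemm_k_iterations * kK;                        // 1024 K-elements per slice

    for (int k = 0; k < split_k_slices; ++k)
    {
        int begin = k * gemm_k_size;
        int end = (k + 1 == split_k_slices) ? K : (k + 1) * gemm_k_size;
        std::printf("slice %d reduces K range [%d, %d)\n", k, begin, end);
    }
    return 0;
}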
@ -15,36 +15,18 @@
*/

#include "customAllReduceKernels.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include <tuple>

namespace tensorrt_llm::kernels
{

using tensorrt_llm::common::hadd2;
using tensorrt_llm::common::datatype_enum;
using tensorrt_llm::common::divUp;

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t myHadd2(const uint32_t& a, const uint32_t& b)
{
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b)
{
    uint32_t c;
    asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
@ -69,80 +51,62 @@ static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_add
////////////////////////////////////////////////////////////////////////////////////////////////////

// Type Converter that packs data format to 128 bits data type
template <typename T>
struct ARTypeConverter
//
using PackedFloat = union
{
    using Type = uint4;
    int4 packed;
    float unpacked[4];
};

using PackedHalf = union
{
    int4 packed;
    half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes
{
};

template <>
struct PackedOn16Bytes<float>
{
    using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half>
{
    using Type = PackedHalf;
};

#ifdef ENABLE_BF16
template <>
struct ARTypeConverter<__nv_bfloat16>
using PackedBFloat16 = union
{
    using Type = bf168;
    int4 packed;
    __nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16>
{
    using Type = PackedBFloat16;
};
#endif

// add two 128b data
template <typename T_IN, typename T_COMP>
inline __device__ T_IN add128b(T_IN a, T_IN b);

template <>
inline __device__ uint4 add128b<uint4, uint16_t>(uint4 a, uint4 b)
{
    uint4 c;
    c.x = myHadd2(a.x, b.x);
    c.y = myHadd2(a.y, b.y);
    c.z = myHadd2(a.z, b.z);
    c.w = myHadd2(a.w, b.w);
    return c;
}

template <>
inline __device__ uint4 add128b<uint4, uint32_t>(uint4 a, uint4 b)
{
    uint4 c;
    c.x = fadd(a.x, b.x);
    c.y = fadd(a.y, b.y);
    c.z = fadd(a.z, b.z);
    c.w = fadd(a.w, b.w);
    return c;
}

#ifdef ENABLE_BF16
template <>
inline __device__ bf168 add128b<bf168, __nv_bfloat16>(bf168 a, bf168 b)
{
    bf168 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    c.z = hadd2(a.z, b.z);
    c.w = hadd2(a.w, b.w);
    return c;
}
#endif

// init 128bits data with 0
template <typename T>
inline __device__ T init_packed_type();

template <>
inline __device__ uint4 init_packed_type()
inline __device__ int4 add128b(T& a, T& b)
{
    return make_uint4(0u, 0u, 0u, 0u);
    T c;
    c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
    c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
    c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
    c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
    return c.packed;
}

#ifdef ENABLE_BF16
template <>
inline __device__ bf168 init_packed_type()
{
    bf168 val;
    uint4& val_u = reinterpret_cast<uint4&>(val);
    val_u = make_uint4(0u, 0u, 0u, 0u);
    return val;
}
#endif

__inline__ __device__ void multi_gpu_barrier(
    uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx)
{
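The unions this hunk introduces let one 128-bit vector load feed element-wise math without inline PTX. Below is a minimal standalone CUDA sketch of the same idea; the names are local to the sketch, not taken from the file.

#include <cuda_fp16.h>

// One 16-byte int4 load carries four half2 lanes (8 fp16 values).
union PackedHalfSketch
{
    int4 packed;       // what is loaded/stored: a single 128-bit transaction
    half2 unpacked[4]; // what the arithmetic sees
};

__global__ void packedAddSketch(const int4* a, const int4* b, int4* out)
{
    PackedHalfSketch va, vb, vc;
    va.packed = a[threadIdx.x];
    vb.packed = b[threadIdx.x];
#pragma unroll
    for (int i = 0; i < 4; ++i)
    {
        vc.unpacked[i] = __hadd2(va.unpacked[i], vb.unpacked[i]); // fp16x2 add
    }
    out[threadIdx.x] = vc.packed;
}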
@ -179,10 +143,10 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
    const int tidx = threadIdx.x;

    // The number of elements packed into one for comms
    static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
    static constexpr int NUM_ELTS = 16 / sizeof(T);

    // Packed data type for comms
    using PackedType = typename ARTypeConverter<T>::Type;
    using PackedStruct = typename PackedOn16Bytes<T>::Type;

    multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
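The new 16 / sizeof(T) expression generalizes the old hard-coded packing factors; a quick compile-time check (illustrative only, not from the diff) confirms they agree for the types the kernel handles:

#include <cuda_fp16.h>

static_assert(16 / sizeof(half) == 8, "eight fp16 elements per 16-byte packet");
static_assert(16 / sizeof(float) == 4, "four fp32 elements per 16-byte packet");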
@ -204,23 +168,24 @@
    for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
    {
        // Iterate over the different ranks/devices on the node to load the values.
        PackedType vals[RANKS_PER_NODE];
        PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
        for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][iter_offset])[0];
            vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][iter_offset]);
        }

        // Sum the values from the different ranks.
        PackedType sums = init_packed_type<PackedType>();
        PackedStruct sums;
        sums.packed = {0, 0, 0, 0};
#pragma unroll
        for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            sums = add128b<PackedType, T>(sums, vals[ii]);
            sums.packed = add128b(sums, vals[ii]);
        }

        // Store to the destination buffer.
        reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset])[0] = sums;
        *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
    }
}
@ -234,10 +199,10 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
    const int tidx = threadIdx.x;

    // The number of elements packed into one for comms
    static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
    static constexpr int NUM_ELTS = 16 / sizeof(T);

    // Packed data type for comms
    using PackedType = typename ARTypeConverter<T>::Type;
    using PackedType = typename PackedOn16Bytes<T>::Type;

    // The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
    const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
@ -268,19 +233,20 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll
        for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][local_offset])[0];
            vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][local_offset]);
        }

        // Sum the values from the different ranks.
        PackedType sums = init_packed_type<PackedType>();
        PackedType sums;
        sums.packed = {0, 0, 0, 0};
#pragma unroll
        for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
        {
            sums = add128b<PackedType, T>(sums, vals[ii]);
            sums.packed = add128b(sums, vals[ii]);
        }

        // Store to the local buffer.
        reinterpret_cast<PackedType*>(&src_d[0][local_offset])[0] = sums;
        *reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
    }

    // sync threads to make sure all block threads have the sums
@ -318,8 +284,8 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
            {
                continue;
            }
            reinterpret_cast<PackedType*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])[0]
                = reinterpret_cast<PackedType*>(&src_d[ii][offset_rank])[0];
            *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])
                = *reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
        }
    }
}
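For context on the two kernels these hunks patch, a rough per-rank traffic model is useful: the one-shot kernel has every rank read all peers' full buffers and reduce locally, while the two-shot kernel first reduces only this rank's slice across peers and then gathers the reduced slices back. The numbers below are illustrative assumptions (8 ranks, 1 Mi fp16 elements), not measurements from the commit:

#include <cstdio>

int main()
{
    const size_t world_size = 8, N = 1 << 20, elt_size = 2; // hypothetical fp16 tensor
    size_t one_shot = world_size * N * elt_size;  // read every peer's full buffer
    size_t two_shot = 2 * N * elt_size;           // ~reduce-scatter + all-gather of slices
    std::printf("one-shot: %zu bytes, two-shot: %zu bytes per rank\n", one_shot, two_shot);
    return 0;
}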
@ -434,18 +400,22 @@ AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSiz
{
    void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
    AllReduceParams params;
    // Even plugins use ping buffers, odd plugins use pong.
    // That way, we don't need to wait for other GPUs to be done
    // before copying input tensor to workspace.
    const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;

    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_comm_buffer_ptrs[i] = buffer_ptrs[i];
        params.peer_comm_buffer_ptrs[i] = buffer_ptrs[buffer_offset + i];
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[tpSize + i]);
        params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[2 * tpSize + i]);
    }
    for (int i = 0; i < tpSize; ++i)
    {
        params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[2 * tpSize + i]);
        params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(buffer_ptrs[3 * tpSize + i]);
    }
    params.barrier_flag = flag_value;
    params.ranks_per_node = tpSize;
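The indexing that the updated deserialize() assumes can be restated as a small sketch (the helper names are invented for illustration): the pointer table holds four groups of tpSize entries, [ping comm | pong comm | barrier in | barrier out], and the flag's parity selects ping or pong so consecutive plugin invocations never reuse a buffer a peer may still be reading.

#include <cstddef>
#include <cstdint>

inline size_t commBufferIndex(size_t tpSize, uint32_t flag_value, int rank)
{
    const size_t buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize; // ping vs. pong
    return buffer_offset + rank;                                     // slots [0, 2 * tpSize)
}

inline size_t barrierInIndex(size_t tpSize, int rank)  { return 2 * tpSize + rank; }
inline size_t barrierOutIndex(size_t tpSize, int rank) { return 3 * tpSize + rank; }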
@ -463,19 +433,18 @@ void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts,

    if (dataType == datatype_enum::TYPE_FP32)
    {
        using T = CustomARCommTypeConverter<float>::Type;
        kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
        kernels::invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
    }
    else if (dataType == datatype_enum::TYPE_FP16)
    {
        using T = CustomARCommTypeConverter<half>::Type;
        kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
        kernels::invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
    }
#ifdef ENABLE_BF16
    else if (dataType == datatype_enum::TYPE_BF16)
    {
        using T = CustomARCommTypeConverter<__nv_bfloat16>::Type;
        kernels::invokeOneOrTwoShotAllReduceKernel<T>(params, strat, stream);
        kernels::invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
    }
#endif
    else
    {
        TLLM_THROW("Unsupported dataType for customAllReduce");
@ -30,7 +30,6 @@ namespace tensorrt_llm::kernels
{

constexpr size_t WARP_SIZE = 32;
constexpr size_t CUSTOM_AR_SIZE_THRESHOLD = 50331648;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
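As a worked bit of arithmetic on the constant above, 50331648 is exactly 48 MiB; a one-line compile-time check (illustrative, not part of the commit):

static_assert(50331648 == 48 * 1024 * 1024, "CUSTOM_AR_SIZE_THRESHOLD is 48 MiB");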
@ -45,16 +44,6 @@ enum class AllReduceStrategyType : int8_t
    AUTO = 3,
};

#ifdef ENABLE_BF16
typedef struct bf168
{
    __nv_bfloat162 x;
    __nv_bfloat162 y;
    __nv_bfloat162 z;
    __nv_bfloat162 w;
} bf168;
#endif

struct AllReduceParams
{
    size_t elts_total;
@ -76,26 +65,6 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy

void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream);

template <typename T>
struct CustomARCommTypeConverter
{
    using Type = uint32_t;
};

template <>
struct CustomARCommTypeConverter<half>
{
    using Type = uint16_t;
};

#ifdef ENABLE_BF16
template <>
struct CustomARCommTypeConverter<__nv_bfloat16>
{
    using Type = __nv_bfloat16;
};
#endif

void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
    common::datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream);
@ -19,27 +19,253 @@ namespace tensorrt_llm
namespace kernels
{
// clang-format off
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin[];
// SingleQueryToken kernels.
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin[];

extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len;
// MultiQueryToken kernels.
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];

// SingleQueryToken kernels.
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len;

// MultiQueryToken kernels.
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
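These symbols follow one naming schema: dt is the activation data type, d the head dimension, beam the beam width, kvt the KV-cache data type, pagedKV_<n> (when present) the tokens-per-page of a paged KV cache, nqpkv the number of query heads per KV head (0 appears to mark the MultiQueryToken builds), m the M-tile size, and sm the target GPU architecture. A minimal sketch of that schema follows; buildXqaCubinName is a hypothetical helper for illustration only, since the real symbols are generated at build time rather than composed at runtime:

#include <sstream>
#include <string>

// Hypothetical helper (not part of this diff) that reproduces the cubin
// symbol-name schema used by the declarations above.
std::string buildXqaCubinName(std::string const& dt, unsigned headDim, unsigned beamWidth,
    std::string const& kvt, unsigned tokensPerPage, unsigned nqpkv, unsigned mTileSize, unsigned sm)
{
    std::ostringstream oss;
    oss << "xqa_kernel_dt_" << dt << "_d_" << headDim << "_beam_" << beamWidth << "_kvt_" << kvt;
    if (tokensPerPage > 0) // paged KV-cache builds encode the tokens-per-page value
    {
        oss << "_pagedKV_" << tokensPerPage;
    }
    oss << "_nqpkv_" << nqpkv << "_m_" << mTileSize << "_sm_" << sm << "_cubin";
    return oss.str();
}

For example, buildXqaCubinName("bf16", 128, 1, "e4m3", 64, 8, 8, 90) reproduces xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, and the matching _cubin_len symbol carries the length of that embedded blob.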
@ -49,21 +275,137 @@ static const struct XQAKernelMetaInfo
unsigned int mHeadDim;
unsigned int mBeamWidth;
unsigned int mNumQHeadsOverKV;
unsigned int mMTileSize;
unsigned int mTokensPerPage;
bool mPagedKVCache;
bool mMultiQueryTokens;
unsigned int mSM;
const unsigned long long* mCubin;
unsigned int mCubinSize;
const char* mFuncName;
} sXqaKernelMetaInfo[] = {
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_0_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_1_nqpkv_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_2_nqpkv_8_sm_90_cubin_len, "kernel_mha"}
// SingleQueryToken kernels.
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_FP16, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_int8_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 0, false, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 128, true, false, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_8_m_8_sm_90_cubin_len, "kernel_mha"},
// MultiQueryToken kernels.
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_80, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_80_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_86, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_86_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_89, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_89_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 0, false, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_128_beam_1_kvt_fp16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 0, false, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 0, false, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_bf16_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"}
};
// clang-format on
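Given this table, kernel selection reduces to matching every configuration field against a row. The following is a minimal sketch (an assumption, not code from this diff): the two leading XQAKernelMetaInfo members fall outside the hunk shown above, so mDataType and mKVDataType are assumed names, and Data_type is assumed to be the enum behind the DATA_TYPE_* constants.

#include <cstddef>

// Sketch only: linear scan over the metadata table. Field names for the two
// leading members are assumptions, since they are not visible in this hunk.
XQAKernelMetaInfo const* findXqaKernel(Data_type dt, Data_type kvDt, unsigned headDim,
    unsigned beamWidth, unsigned nqpkv, unsigned mTileSize, unsigned tokensPerPage,
    bool pagedKV, bool multiQueryTokens, unsigned sm)
{
    for (std::size_t i = 0; i < sizeof(sXqaKernelMetaInfo) / sizeof(sXqaKernelMetaInfo[0]); ++i)
    {
        XQAKernelMetaInfo const& k = sXqaKernelMetaInfo[i];
        if (k.mDataType == dt && k.mKVDataType == kvDt && k.mHeadDim == headDim
            && k.mBeamWidth == beamWidth && k.mNumQHeadsOverKV == nqpkv
            && k.mMTileSize == mTileSize && k.mTokensPerPage == tokensPerPage
            && k.mPagedKVCache == pagedKV && k.mMultiQueryTokens == multiQueryTokens
            && k.mSM == sm)
        {
            return &k;
        }
    }
    return nullptr; // no precompiled cubin covers this configuration
}

For instance, findXqaKernel(DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 8, 8, 64, true, false, kSM_90) would land on the bf16/e4m3 paged-KV row (64 tokens per page) for SM 90 in the table above.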
Some files were not shown because too many files have changed in this diff; the remaining per-file diffs are suppressed because they are too large.