TensorRT-LLMs/cpp/kernels/fmha_v2/train_ops
qsang-nv 0fd59d64ab
infra: open source fmha v2 kernels (#4185)
* add fmha repo
* fix format
* fix code style
* fix header
* fix header kernel_traits.h
* add .gitignore file
* add SLIDING_WINDOW_ATTENTION
* fix style
* fix format
* update setup.py
* update build_wheel.py

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Signed-off-by: qsang-nv <200703406+qsang-nv@users.noreply.github.com>
2025-05-15 10:56:34 +08:00
hopper/
kernels/
apex_mha_api.cpp
apex_mha_kernels.cu
bert_mha_train_api.cpp
CMakeLists.txt
fmha_bmark.py
fmha_noloop_reduce.cu
fmha_unit_test.py
fp8_mha_api.cpp
fused_multihead_attention_dgrad_kernel_1xN_flash.h
fused_multihead_attention_dgrad_kernel_1xN_noloop.h
fused_multihead_attention_dgrad_kernel_1xN_reload_noloop.h
fused_multihead_attention_dgrad_kernel_1xN_reload.h
fused_multihead_attention_dgrad_kernel_1xN.h
fused_multihead_attention_flash_attention_fprop_kernel.h
fused_multihead_attention_fprop_kernel_1xN_reload.h
fused_multihead_attention_fprop_kernel_1xN.h
fused_multihead_attention_fprop_kernel.h
fused_multihead_attention_fprop.h
gmem_tile_d.h
Makefile
my_utils.py
myenv.sh
philox.h
README.md
smem_tile_d.h
smem_tile_dq.h
static_switch.h
te_mha.py
test_bmm.cpp
test.cpp
test.h
train_setup.py

Running the unit test

Clone APEX under the train_ops folder:

cd train_ops
git clone https://github.com/NVIDIA/apex.git
cd apex
git submodule update --init --recursive
cd ..

From the project root, launch the build container:

cd docker
make launch_docker

Inside the container, the repository is mounted at /repo. Build the extension:

export TORCH_CUDA_ARCH_LIST="8.0;9.0"
cd /repo/train_ops
python train_setup.py
mkdir -p build && cd build && cmake .. && make -j
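TORCH_CUDA_ARCH_LIST selects the GPU architectures the extension is compiled for, as a semicolon-separated list of major.minor compute capabilities ("8.0;9.0" above targets Ampere and Hopper). As an illustration of the format only, a hypothetical helper (not part of this repo) that builds such a string could look like:

```python
# Hypothetical helper (not part of this repo): build a
# TORCH_CUDA_ARCH_LIST-style string from (major, minor) compute
# capabilities, e.g. (8, 0) for A100 and (9, 0) for H100.
def arch_list(capabilities):
    # Deduplicate, sort, and join as "major.minor" entries.
    return ";".join(f"{major}.{minor}" for major, minor in sorted(set(capabilities)))

print(arch_list([(9, 0), (8, 0)]))  # -> 8.0;9.0
```

Restricting the list to the architectures you actually have keeps compile times down.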

Note that we use flash attention by default.

Then, back in train_ops, run the test script:

python fmha_unit_test.py
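A unit test like this typically checks the fused kernels against a slow reference implementation. As an illustration only (this is not the code fmha_unit_test.py actually uses), a minimal single-head attention reference in NumPy looks like:

```python
import numpy as np

def reference_attention(q, k, v):
    """Naive single-head attention: softmax(Q K^T / sqrt(d)) V.

    q, k, v: arrays of shape (seq_len, head_dim). Illustrative
    reference only, not the repo's test code.
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                 # (s, s) attention logits
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)    # row-wise softmax
    return probs @ v                              # (s, head_dim) output

rng = np.random.default_rng(0)
s, d = 8, 16
q, k, v = (rng.standard_normal((s, d)) for _ in range(3))
out = reference_attention(q, k, v)
print(out.shape)  # -> (8, 16)
```

A fused kernel's output would then be compared against such a reference with a tolerance appropriate to the datatype (looser for fp16/fp8 than fp32).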