
# Running the unit test

Under the `train_ops` folder, clone APEX:

```bash
cd train_ops
git clone https://github.com/NVIDIA/apex.git
cd apex
git submodule update --init --recursive
cd ..
```

From the project root, launch the build container:

```bash
cd docker
make launch_docker
```

Then, inside the container (where `/repo` is the repository mount point), build the extension:

```bash
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
cd /repo/train_ops
python train_setup.py
mkdir -p build && cd build && cmake .. && make -j
```
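`TORCH_CUDA_ARCH_LIST` controls which SM architectures the extension is compiled for; `8.0` targets Ampere (e.g. A100) and `9.0` targets Hopper (e.g. H100). As an illustrative sketch (the helper below is hypothetical, not part of this repo), the list is just semicolon-joined `major.minor` compute capabilities:

```python
# Hypothetical helper: format (major, minor) compute-capability tuples
# into the string shape TORCH_CUDA_ARCH_LIST expects, e.g. "8.0;9.0".
def arch_list(capabilities):
    """capabilities: iterable of (major, minor) tuples, e.g. [(8, 0), (9, 0)]."""
    return ";".join(f"{major}.{minor}" for major, minor in sorted(set(capabilities)))

print(arch_list([(9, 0), (8, 0)]))  # -> 8.0;9.0
```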

Note that we use flash attention by default.
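For intuition, "flash attention" computes the same `softmax(QK^T / sqrt(d)) V` as unfused attention, but processes K/V in tiles with an online softmax so the full `(seq, seq)` score matrix is never materialized. The NumPy sketch below is purely illustrative of that algorithm (it is not the kernel in this repo, which runs fused in CUDA):

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference: softmax(QK^T / sqrt(d)) V, materializing all scores."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def flash_attention(q, k, v, block=32):
    """Tiled attention with an online softmax: K/V are consumed in
    blocks, keeping a running row max (m) and normalizer (l) so only
    (seq, block) partial scores exist at any time."""
    seq, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    m = np.full(seq, -np.inf)              # running max of scores per row
    l = np.zeros(seq)                      # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T * scale               # partial scores for this tile
        m_new = np.maximum(m, s.max(axis=-1))
        p = np.exp(s - m_new[:, None])
        corr = np.exp(m - m_new)           # rescale previously accumulated terms
        l = l * corr + p.sum(axis=-1)
        out = out * corr[:, None] + p @ vb
        m = m_new
    return out / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 64)) for _ in range(3))
assert np.allclose(flash_attention(q, k, v), naive_attention(q, k, v))
```

The tiling keeps memory traffic proportional to the sequence length rather than its square, which is what makes the fused kernels fast for long sequences.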

Then, in `train_ops`, run the test script:

```bash
python fmha_unit_test.py
```