tburt-nv 6147452158
[https://nvbugs/4141427][chore] Add more details to LICENSE file (#9881)
Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com>
2025-12-13 08:35:31 +08:00

Running the unit test

Starting from the project root, clone NVIDIA Apex into the train_ops folder and initialize its submodules:

cd train_ops
git clone https://github.com/NVIDIA/apex.git
cd apex
git submodule update --init --recursive
cd ..
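
To confirm that the recursive submodule update succeeded, one option is to inspect `git submodule status --recursive`, which prefixes each uninitialized submodule with `-`. This is a minimal sketch that assumes it runs from train_ops, next to the `apex` checkout. The helper names are hypothetical and not part of this repository.

```python
import subprocess


def apex_submodule_status(apex_dir: str = "apex") -> str:
    """Run `git submodule status --recursive` inside the Apex checkout (hypothetical helper)."""
    return subprocess.run(
        ["git", "-C", apex_dir, "submodule", "status", "--recursive"],
        capture_output=True, text=True, check=True,
    ).stdout


def uninitialized_submodules(status_output: str) -> list[str]:
    """Return the paths that `git submodule status` marks with a leading '-'.

    Each status line looks like "<prefix><sha1> <path> [(<describe>)]". The prefix is
    ' ' (in sync), '-' (not initialized), '+' (checked-out commit differs from the
    recorded one) or 'U' (merge conflicts).
    """
    missing = []
    for line in status_output.splitlines():
        if line.startswith("-"):
            # After dropping the prefix, the fields are the SHA-1 and then the path.
            missing.append(line[1:].split()[1])
    return missing
```

After a successful update, `uninitialized_submodules(apex_submodule_status())` should return an empty list. If any path is listed, rerun the submodule command inside apex.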

Next, return to the project root (the directory that contains train_ops) and launch the build container:

cd docker
make launch_docker

Then, inside the container (where the repository is mounted at /repo), set the target GPU architectures and build:

export TORCH_CUDA_ARCH_LIST="8.0;9.0"
cd /repo/train_ops
python train_setup.py
mkdir -p build && cd build && cmake .. && make -j
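
TORCH_CUDA_ARCH_LIST names the CUDA compute capabilities to compile for: 8.0 is Ampere (A100-class) and 9.0 is Hopper (H100-class). If train_setup.py builds through torch.utils.cpp_extension, which reads this variable, each entry becomes an nvcc `-gencode` flag. That build path is an assumption here. This sketch is a simplified, hypothetical re-implementation of the mapping for illustration only, not PyTorch's actual code. It ignores named architectures such as "Ampere" and suffixed ones such as "9.0a".

```python
def gencode_flags(arch_list: str) -> list[str]:
    """Map a TORCH_CUDA_ARCH_LIST value to nvcc -gencode flags (simplified sketch)."""
    flags = []
    # Entries may be separated by ';' or whitespace, e.g. "8.0;9.0" or "8.0 9.0+PTX".
    for entry in arch_list.replace(" ", ";").split(";"):
        if not entry:
            continue
        arch, plus_ptx = entry, False
        if entry.endswith("+PTX"):
            arch, plus_ptx = entry[: -len("+PTX")], True
        major, minor = arch.split(".")
        num = f"{major}{minor}"  # "8.0" -> "80", i.e. compute capability 8.0
        # Machine code (SASS) for this exact architecture.
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if plus_ptx:
            # Also embed PTX so newer GPUs can JIT-compile the kernels.
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags
```

For example, `gencode_flags("8.0;9.0")` yields one `sm_80` flag and one `sm_90` flag. To find your GPU's capability, call `torch.cuda.get_device_capability()`. It returns a `(major, minor)` tuple, such as `(9, 0)` on an H100.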

Note that we use flash attention by default.
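
Flash attention computes ordinary softmax attention, softmax(Q K^T / sqrt(d)) V. Instead of building the full score matrix, it streams over keys and values in blocks and keeps only a running maximum and a running softmax denominator. This plain-Python sketch shows that idea for a single query row and can be checked against the direct computation. It is illustrative only: the function names and block size are hypothetical, and it does not reflect how the CUDA kernels in this directory are organized.

```python
import math


def naive_attention_row(q, keys, values):
    """softmax(q . K^T / sqrt(d)) V for one query row, materializing every score."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def flash_attention_row(q, keys, values, block=2):
    """Same result, visiting keys/values one block at a time (online softmax).

    Keeps only a running max `m`, a running denominator `l`, and an unnormalized
    output accumulator `acc`, so the full score row is never stored.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        k_blk, v_blk = keys[start:start + block], values[start:start + block]
        s_blk = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s_blk))
        # Rescale everything accumulated under the old max to the new max
        # (exp(-inf) == 0.0 on the first block, so the empty state drops out).
        correction = math.exp(m - m_new)
        p_blk = [math.exp(s - m_new) for s in s_blk]
        l = l * correction + sum(p_blk)
        acc = [a * correction + sum(p * v[j] for p, v in zip(p_blk, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

Each block rescales the earlier partial sums by exp(m_old - m_new). As a result, the output matches the direct computation up to floating-point rounding for any block size.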

Then, from /repo/train_ops (one level up from build), run the test script:

python fmha_unit_test.py