diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index dc21872ac6..7f7164c9a2 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -1138,6 +1138,10 @@ CUBIN_EXPORT __global__ auto& xBar = smem.xBar[idxXBuf]; auto& vBar = smem.vBar[idxVBuf]; auto const& vBuf = smem.vBuf(idxVBuf); +#if !SWAP_AB + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); +#endif xBar.produced.arrive_and_wait(); #if SKIP_SOFTMAX_ATTN bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar @@ -1153,8 +1157,6 @@ CUBIN_EXPORT __global__ { arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); #if !SWAP_AB - CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; - auto& vtBuf = smem.vtBuf(idxVBuf); vtBar.consumed.arrive_and_wait(); transposeVTile(warpRank, laneId(), vtBuf, vBuf); vBar.consumed.arrive();