TensorRT-LLMs/triton_kernels/compaction_details/_masked_compaction.py
Anish Shanbhag 24ac86c485
[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-28 19:56:32 -08:00

25 lines
1.1 KiB
Python

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/compaction_details/_masked_compaction.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
@triton.jit
def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
    """Per-row stable compaction of K (value, index) pairs under a bitmask.

    Each program handles one row (program_id(0)). An entry at position j of the
    row is "active" iff bit (Yi[j] % 32) of the 32-bit word BitMask[row, Yi[j] // 32]
    is set — i.e. the mask is addressed by the *stored index*, not by position j.
    Active pairs are packed to the front of the output row in their original
    order; the remaining tail slots are filled with `sentinel` in both outputs.

    Args:
        Yv, Yi: input pointers, row-major with row stride K (values / indices).
        BitMask: pointer to the per-row bitfield; strides given explicitly.
        stride_bm, stride_bn: row / word strides of BitMask (in elements).
        RetYv, RetYi: output pointers, same row-major K layout as the inputs.
        sentinel: fill value written to inactive (compacted-out) slots.
        K: compile-time row width; must be a power of 2 for tl.arange.
    """
    pid_m = tl.program_id(0)
    # Load this program's full row of K values and K indices.
    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
    # Locate each index's bit in the row's bitfield: word `div`, bit `rem`.
    div = yi // 32
    rem = yi % 32
    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
    # Exclusive prefix sum of the active flags = destination slot of each
    # active element when packed to the front (stable: original order kept).
    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
    active_flags = active_bits.to(tl.int1)
    # For inactive lanes, offset the write index so position i lands at
    # (#active before i) + (K - 1 - i): the tail fills back-to-front, and the
    # combined active/inactive targets form a permutation of [0, K) — no
    # two lanes collide. Active lanes get offset 0 (front-packed).
    rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
    write_indx = exc_cumsum + rev_arange
    # Inactive entries are replaced by `sentinel` in both outputs.
    yv = tl.where(active_flags, yv, sentinel)
    yi = tl.where(active_flags, yi, sentinel)
    # Scatter the row into the outputs at the computed permutation.
    tl.store(RetYv + pid_m * K + write_indx, yv)
    tl.store(RetYi + pid_m * K + write_indx, yi)