# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/compaction_details/_masked_compaction.py
# Triton is licensed under the MIT License.

import triton
import triton.language as tl


@triton.jit
def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
    # Each program compacts one row of K (value, index) pairs: pairs whose
    # index has its bit set in BitMask are packed to the front of the row in
    # order; the remaining slots are filled with `sentinel`.
    pid_m = tl.program_id(0)
    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
    # Locate the bit for each index: word `div`, bit position `rem` within it.
    div = yi // 32
    rem = yi % 32
    active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
    # Exclusive cumulative sum gives each active element its write position.
    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
    active_flags = active_bits.to(tl.int1)
    # Inactive elements fill the tail of the row in reverse order, so active
    # and inactive write indices never collide.
    rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
    write_indx = exc_cumsum + rev_arange
    # Replace inactive values/indices with the sentinel before writing back.
    yv = tl.where(active_flags, yv, sentinel)
    yi = tl.where(active_flags, yi, sentinel)
    tl.store(RetYv + pid_m * K + write_indx, yv)
    tl.store(RetYi + pid_m * K + write_indx, yi)
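

# --- Usage sketch (not part of the vendored file) ---------------------------
# A minimal, hedged example of how this kernel might be launched. The tensor
# names, shapes, and bitmask layout below are assumptions for illustration,
# not an API that Triton or TensorRT-LLM exposes around this kernel.
#
#   import torch
#
#   M, K = 8, 128                     # K must be a power of two (tl.arange)
#   n_vals = 1024                     # index values live in [0, n_vals)
#   yv = torch.randn(M, K, device="cuda")
#   yi = torch.randint(0, n_vals, (M, K), device="cuda", dtype=torch.int32)
#   # One bit per candidate index value, 32 bits per int32 word.
#   bitmask = torch.randint(-2**31, 2**31 - 1, (M, n_vals // 32),
#                           device="cuda", dtype=torch.int32)
#   ret_yv = torch.empty_like(yv)
#   ret_yi = torch.empty_like(yi)
#   _masked_compaction[(M,)](yv, yi, bitmask,
#                            bitmask.stride(0), bitmask.stride(1),
#                            ret_yv, ret_yi, sentinel=-1, K=K)
#   # Row m of ret_yi now holds the active indices of row m packed in order,
#   # followed by `sentinel` in the remaining slots (ret_yv likewise).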