# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
|
||
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/compaction.py
|
||
# Triton is licensed under the MIT License.
|
||
|
||
import torch
|
||
from .compaction_details._masked_compaction import _masked_compaction
|
||
from .tensor import Bitmatrix
|
||
|
||
|
||
def compaction(yv, yi, bitmask, sentinel=-1):
    """
    Return compacted copies of *yv* and *yi* based on a per-row bitmask.

    Only the elements whose index appears among the active bits of *bitmask*
    are kept; the rest are replaced by *sentinel*. Kept elements preserve
    their original left-to-right order.

    Parameters
    ----------
    yv : torch.Tensor, shape (B, K)
        Values tensor.
    yi : torch.Tensor, shape (B, K), dtype torch.long
        Integer indices (0 ≤ index < 32) associated with *yv*.
    bitmask : torch.Tensor, shape (B,) **or** (B, 32)
        Per-row mask of active indices. See the in-place version for details.
    sentinel : int, default -1
        Value written into dropped positions of the returned tensors.

    Returns
    -------
    (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
        New tensors with the same dtype/device as the inputs.
    """

    n_rows, n_cols = yi.shape
    # Outputs are written entirely by the kernel, so uninitialized storage is fine.
    ret_yv = torch.empty_like(yv)
    ret_yi = torch.empty_like(yi)
    # Unwrap a Bitmatrix to its underlying storage tensor so the kernel
    # receives a plain torch.Tensor it can take strides of.
    if isinstance(bitmask, Bitmatrix):
        bitmask = bitmask.storage.data

    # Launch one Triton program per row; each program compacts its own row.
    # NOTE(review): bitmask.stride(1) requires a 2-D bitmask at this point —
    # a raw 1-D (B,) mask would raise IndexError; confirm callers always pass
    # 2-D storage despite the (B,) option in the docstring.
    _masked_compaction[(n_rows, )](
        yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1),  # inputs
        ret_yv, ret_yi,  # outputs
        sentinel,  # sentinel
        K=n_cols  # constants
    )
    return ret_yv, ret_yi
|
||
|
||
|
||
def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
    """
    Pure-PyTorch reference implementation of `compaction`.

    Keeps only the elements of *yv* / *yi* whose index is an active bit of
    *bitmask*; dropped elements are moved to the end of each row and filled
    with *sentinel*. Kept elements preserve their left-to-right order.

    Parameters
    ----------
    yv : torch.Tensor, shape (B, K)
        Values tensor.
    yi : torch.Tensor, shape (B, K), integer dtype
        Indices associated with *yv* (0 ≤ index < 32 * n_words).
    bitmask : torch.Tensor, shape (B,) or (B, n_words), integer dtype
        Per-row mask of active indices, 32 bits per mask word.
    sentinel : int, default -1
        Value written into dropped positions of the returned tensors.

    Returns
    -------
    (yv_out, yi_out) : tuple of torch.Tensor, each shape (B, K)
        New tensors with the same dtype/device as the inputs.
    """
    B = yi.shape[0]
    device = yi.device
    # Expand bitmask to a boolean matrix of active bits: one column per bit.
    w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype))
    bits = (bitmask.unsqueeze(-1) & w) != 0
    # BUG FIX: the previous `bits.flatten(start_dim=-2)` collapsed the batch
    # dimension when *bitmask* is 1-D (bits is then (B, 32) and start_dim=-2
    # is dim 0, producing a 1-D tensor that breaks the gather below).
    # reshape keeps the batch dimension for both (B,) and (B, n_words) masks.
    mask = bits.reshape(B, -1)
    # For every yi element decide whether it should be kept
    keep = mask.gather(1, yi.long())
    # Build a stable permutation that brings all "keep" items forward
    # False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
    order = (~keep).to(torch.int).argsort(dim=1, stable=True)
    # Re-order tensors according to above permutation
    yi_sorted = yi.gather(1, order)
    yv_sorted = yv.gather(1, order)
    # Overwrite the dropped (now trailing) positions with the sentinel
    keep_sorted = keep.gather(1, order)
    yi_sorted[~keep_sorted] = sentinel
    yv_sorted[~keep_sorted] = sentinel
    return yv_sorted, yi_sorted
|