TensorRT-LLMs/triton_kernels/routing.py
Anish Shanbhag 24ac86c485
[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-28 19:56:32 -08:00

397 lines
16 KiB
Python

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/routing.py
# Triton is licensed under the MIT License.
import torch
import triton
from dataclasses import dataclass, field
from .routing_details._routing_compute import _combined_routing_compute
from .routing_details._routing_compute import _combined_routing_memset
from .routing_details._routing_compute import _routing_clear_bitmatrix
from .routing_details._expt_data import _expt_data_memset
from .routing_details._expt_data import _expt_data_compute
from .target_info import is_hip
@dataclass
class GatherIndx:
"""
Indices for an operation that performs:
Y = X[src_idx, :]
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ScatterIndx:
"""
Indices for an operation that performs:
Y[dst_idx, :] = X
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ExptData:
# hist[i] is the number of tokens routed to expert i
hist: torch.Tensor
# token_offs_raw[i] is the offset of the first token routed
# to expert i in an expert-sorted array
token_offs_raw: torch.Tensor
# token_offs_pad[block][i] is the offset of the first token routed
# to expert i in an expert-sorted array, assuming histogram
# rounded to the next multiple of `block`
token_offs_pad: dict[int, torch.Tensor]
# block_id_map[block] contain one value for each `pid`` launched by
# the matrix multiplication kernel launched with BLOCK_M=block:
# - the value is -1 if the `pid` has no work to do
# - otherwise, the value is two int16 (packed as an int32) that
# correspond respectively to (1) the expert assigned to
# the tokens processed by this pid; (2) the block assigned to the
# tokens processed by this pid (think `pid_m` in a regular matmul)
# see `test_routing.py` for a reference implementation and more details
block_pid_map: dict[int, torch.Tensor]
def __post_init__(self):
if self.hist is not None:
assert self.hist.dtype == torch.int32
if self.token_offs_raw is not None:
assert self.token_offs_raw.dtype == torch.int32
if self.token_offs_pad is not None:
for v in self.token_offs_pad.values():
assert v.dtype == torch.int32
if self.block_pid_map is not None:
for v in self.block_pid_map.values():
assert v.dtype == torch.int32
@dataclass
class RoutingData:
gate_scal: torch.Tensor = field()
expt_hist: torch.Tensor = field()
n_expts_tot: int = field()
n_expts_act: int = field()
expt_data: ExptData = None
# Used to make perf annotation cleaner: when we use expert sharding, we can
# use this to tell the "expected" number of local tokens per expert, because
# the actual number can vary per each input.
expected_tokens_per_expt: int = field(default=None)
def n_blocks(self, n_rows, block_m):
if n_rows <= self.n_expts_tot:
return n_rows
else:
return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1
# --------------------------
# sort tokens by expert
# --------------------------
class SortTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
HIST_BLOCK_M = 32
INDX_OFFS_BLOCK_M = 512
MEMSET_BLOCK = 1024
cdiv = triton.cdiv
device = expt_scal.device
dtype = expt_scal.dtype
n_tokens_raw, _ = bitmatrix.shape
n_tokens_pad, n_expts_act = expt_scal.shape
n_gates_pad = n_tokens_pad * n_expts_act
hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
hist = hist[:n_expts_tot]
assert hist.dtype == torch.int32
# scratchpad
expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
# output
topk_indx = combined_indx[:n_gates_pad]
gate_indx = combined_indx[n_gates_pad:]
gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
hist, n_expts_tot, n_gates_pad)
blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
_combined_routing_memset[(blocks1a + blocks1b, )](
combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, #
expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs
partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs
token_offs_combined, token_offs_combined.stride(0), #
blocks1a, block_pid_map, #
block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
)
indx_offs = partial_hist
_combined_routing_compute[(blocks2a + blocks2b, )](
topk_indx, gate_indx, gate_scal, # outputs
expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs
expt_offs, n_tokens_raw, # input shape
HIST_BLOCK_M, n_expts_act, # constants
hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc.
)
ctx.n_tokens_raw = n_tokens_raw
ctx.n_tokens_pad = n_tokens_pad
ctx.n_expts_act = n_expts_act
ctx.save_for_backward(gate_indx)
return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map
@staticmethod
def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
(gate_indx, ) = ctx.saved_tensors
dgate_scal = dgate_scal[gate_indx]
dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
return dgate_scal, None, None, None
def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
# --------------------------
# prune routing
# --------------------------
class PruneRouting(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
from .compaction import compaction
n_tokens_pad = expt_scal.shape[0]
assert n_expts_tot % simulated_ep == 0
_routing_clear_bitmatrix[(n_tokens_pad, )](
bitmatrix.storage.data,
bitmatrix.storage.data.stride(0),
bitmatrix.storage.data.stride(1),
bitmatrix.storage.data.shape[1],
n_expts_tot // simulated_ep,
BLOCK_N=512,
)
# perform compaction to update expt_scal / expt_indx
expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
n_expts_tot = n_expts_tot // simulated_ep
bitmatrix.shape[-1] = n_expts_tot
return expt_scal, expt_indx, bitmatrix
def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
# --------------------------
# expt_data
# --------------------------
def log2_power_of_two(x):
assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
return x.bit_length() - 1
block_m_log2_start = 4
def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
MEMSET_BLOCK = 512
HIST2_BLOCK_M = 512
device = expt_hist.device
n_expts_tot = n_expts_tot
cdiv = triton.cdiv
# block_ms are all powers-of-two between 16 and 128 (inclusive)
block_m_log2_end = 9 if is_hip() else 8
block_m_num = block_m_log2_end - block_m_log2_start
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
# allocate memory
pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
dtype = torch.int32
token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
token_offs_pad = token_offs_combined[1:]
block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division
# compute outputs
token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
block_pid_map = block_pid_map[:, :max_n_tiles]
blocks1 = memset_grid + block_m_num + 1
blocks2 = n_expts_tot * block_m_num
return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
def _unpack_into_dict(x):
block_m_log2_end = block_m_log2_start + x.shape[0]
x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
return x
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
if expt_hist is None:
return ExptData(None, None, None, None)
# this just computes the kernel arguments:
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
expt_hist, n_expts_tot, n_gates)
_expt_data_memset[(blocks1, )](
expt_hist, n_expts_tot, #
token_offs_combined, token_offs_combined.stride(0), #
block_pid_map, #
block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, # optimization parameters
num_warps=4)
_expt_data_compute[(blocks2, )](
expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, # optimization parameters
num_warps=4)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
# --------------------------
# routing
# --------------------------
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
expt_scal, expt_indx, n_expts_tot, bitmatrix)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
from .topk import topk
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, #
apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
n_expts_tot = logits.shape[-1] // simulated_ep
# mutate bitmatrix
if simulated_ep > 1:
expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep)
return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
# --------------------------
# torch reference
# --------------------------
def compute_expt_data_torch(hist, n_expts_tot, n_gates):
# offset for each experts
device = hist.device
token_offs_raw = torch.cumsum(hist, dim=0)
token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw))
token_offs_raw = token_offs_raw.int()
# maximum number of tiles for all values of `block_m` considered
block_ms = [16, 32, 64, 128]
if is_hip():
block_ms.append(256)
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
# fill up tile offset/infos for each block
token_offs_pad = dict()
block_pid_map = dict()
for block_m in block_ms:
n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed
token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
token_offs_pad[block_m] = token_offs_pad[block_m].int()
# compute data required to drive ragged batch matmul
block_pid_map[block_m] = -torch.ones(max_n_tiles, dtype=torch.int32, device=device)
# for e in range(n_expts_tot):
# offset = token_offs_pad[block_m][e]
# for b in range(n_tiles[e]):
# block_pid_map[block_m][offset + b] = (b << 16) + e
col = torch.arange(max_n_tiles, device=device)
map_vals = torch.arange(n_expts_tot, device=device)[:, None] + (col << 16)[None, :]
map_idxs = token_offs_pad[block_m][:-1, None] + col[None, :]
mask = col[None, :] < n_tiles[:, None]
block_pid_map[block_m].index_put_((map_idxs[mask], ), map_vals.int()[mask])
return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
# topk of experts
if has_user_provided_indx:
tk_indx = expt_indx
else:
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
tk_indx = tk_indx.long()
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
tk_indx = tk_indx.int()
return tk_val, tk_indx
def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
has_user_provided_indx = expt_indx is not None
n_gates_pad = logits.shape[0] * n_expts_act
if n_rows is not None:
logits = logits[:n_rows, :]
_, n_expts_tot = logits.shape
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx = topk_torch(logits, n_expts_act, expt_indx, has_user_provided_indx=has_user_provided_indx)
if not sm_first:
expt_scal = torch.softmax(expt_scal, dim=-1)
# sort each token's selections by expert
if not has_user_provided_indx:
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
expt_scal = torch.gather(expt_scal, 1, sort_indices)
# flatten topk data
expt_scal = expt_scal.reshape(-1)
expt_indx = expt_indx.reshape(-1).to(torch.int32)
# sort by expert_id so experts are contiguous for the matmul
topk_indx = torch.argsort(expt_indx, stable=True)
gate_indx = torch.argsort(topk_indx, stable=True)
gate_scal = expt_scal[topk_indx]
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
# compute expt_data
expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx