TensorRT-LLMs/triton_kernels/swiglu.py

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/swiglu.py
# Triton is licensed under the MIT License.
from dataclasses import dataclass
from triton_kernels.numerics import InFlexData, OutFlexData
import torch
import triton
from .swiglu_details._swiglu import _swiglu, _swiglu_fn
from triton_kernels import target_info


@dataclass(frozen=True)
class FlexCtx:
    out_data: OutFlexData = OutFlexData()
    inp_data: InFlexData = InFlexData()
    saturate_inf: bool = False


@dataclass(frozen=True)
class PrecisionConfig:
    limit: float
    flex_ctx: FlexCtx = FlexCtx()


swiglu_fn = _swiglu_fn
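

# SwiGLU as implemented here consumes interleaved columns of the input: even
# columns feed the gated branch x * sigmoid(alpha * x), odd columns feed the
# linear branch (x + 1), both optionally clamped by PrecisionConfig.limit, so
# the output has half the input's last dimension (see `swiglu_torch` at the
# bottom of this file for the eager reference of the same formula).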
class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, alpha, precision_config, routing_data):
        N = a.shape[-1]
        M = a.numel() // N
        assert a.stride()[-1] == 1
        assert a.shape[-1] % 2 == 0
        out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
        flex_ctx = precision_config.flex_ctx
        # optimization hyperparameters
        BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
        num_warps = 4
        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
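        # (`maxnreg` caps registers per thread and is only passed on CUDA
        # targets.) Grid sizing below: with routing data the launch is derived
        # from occupancy (waves per SM divided by num_warps); otherwise from
        # the (M, N // 2) tiling, capped at a small multiple of the SM count
        # so a bounded number of program instances loops over the tiles.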
        # launch semi-persistent kernel
        N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
        num_sms = target_info.num_sms()
        if routing_data is not None:
            waves_per_sm = 32 if target_info.is_hip() else 128
            num_pid = num_sms * (waves_per_sm // num_warps)
            M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
            grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        else:
            M_BLOCKS = triton.cdiv(M, BLOCK_M)
            if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
                grid = (8 * num_sms, )
            else:
                grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        n_tokens = None
        if routing_data is not None:
            n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
        _swiglu[grid](
            flex_ctx.out_data.reinterpret(out),
            flex_ctx.out_data.expected_scale,
            flex_ctx.out_data.actual_scale,
            flex_ctx.out_data.checksum_scale,
            flex_ctx.inp_data.reinterpret(a),
            flex_ctx.inp_data.scale,
            alpha,
            M,
            N // 2,
            a.shape[-1],
            1,
            out.shape[-1],
            1,
            precision_config.limit,
            n_tokens,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            EVEN_N=(N // 2) % BLOCK_N == 0,
            M_BLOCKS=M_BLOCKS,
            N_BLOCKS=N_BLOCKS,
            flexpoint_saturate_inf=flex_ctx.saturate_inf,
            num_warps=num_warps,
            **kwargs,
        )
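        # reshape back to the input's leading dimensions; only the last
        # dimension is halved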
        out = out.view(a.shape[:-1] + out.shape[-1:])
        return out
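

# Note: only `forward` is defined on SwiGLU, so autograd through the fused
# path is not supported here; `swiglu` is a thin functional wrapper.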
def swiglu(a, alpha, precision_config, routing_data=None):
    return SwiGLU.apply(a, alpha, precision_config, routing_data)
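

# Eager PyTorch reference for the computation above (same interleaved even/odd
# column layout and optional clamping); it does not launch the Triton kernel.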
def swiglu_torch(a, alpha, precision_config):
    limit = precision_config.limit
    a_gelu = a[..., ::2]
    if limit is not None:
        a_gelu = a_gelu.clamp(max=limit)
    a_linear = a[..., 1::2]
    if limit is not None:
        a_linear = a_linear.clamp(min=-limit, max=limit)

    out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
    out = out_gelu * (a_linear + 1)
    return out
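

# Minimal usage sketch. The shapes, dtype, limit, and alpha value below are
# illustrative assumptions, not values prescribed by this file; the fused
# `swiglu` path requires a GPU that the Triton kernel supports, while
# `swiglu_torch` runs on any device.
#
#     cfg = PrecisionConfig(limit=7.0)
#     x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
#     y = swiglu(x, 1.702, cfg)                     # fused path, shape (8, 128)
#     y_ref = swiglu_torch(x.float(), 1.702, cfg)   # eager reference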