# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/swiglu.py
# Triton is licensed under the MIT License.

from dataclasses import dataclass
from triton_kernels.numerics import InFlexData, OutFlexData
import torch
import triton
from .swiglu_details._swiglu import _swiglu, _swiglu_fn
from triton_kernels import target_info


@dataclass(frozen=True)
class FlexCtx:
    out_data: OutFlexData = OutFlexData()
    inp_data: InFlexData = InFlexData()
    saturate_inf: bool = False


@dataclass(frozen=True)
class PrecisionConfig:
    limit: float
    flex_ctx: FlexCtx = FlexCtx()

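# Usage sketch (illustration only, not taken from the upstream Triton source;
# the limit value below is an assumption). PrecisionConfig bundles the clamping
# limit and the flexpoint scaling context consumed by the kernel launch below:
#
#     cfg = PrecisionConfig(limit=10.0)  # default FlexCtx
#     cfg = PrecisionConfig(limit=10.0, flex_ctx=FlexCtx(saturate_inf=True))
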
swiglu_fn = _swiglu_fn


class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, alpha, precision_config, routing_data):
        N = a.shape[-1]
        # flatten all leading dimensions; treat `a` as an (M, N) matrix
        M = a.numel() // N
        assert a.stride()[-1] == 1
        assert a.shape[-1] % 2 == 0
        out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
        flex_ctx = precision_config.flex_ctx
        # optimization hyperparameters
        BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
        num_warps = 4
        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
        # launch semi-persistent kernel
        N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
        num_sms = target_info.num_sms()
        if routing_data is not None:
            waves_per_sm = 32 if target_info.is_hip() else 128
            num_pid = num_sms * (waves_per_sm // num_warps)
            M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
            grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        else:
            M_BLOCKS = triton.cdiv(M, BLOCK_M)
            if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
                grid = (8 * num_sms, )
            else:
                grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        n_tokens = None
        if routing_data is not None:
            # total number of routed tokens, taken from the routing metadata
            n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
        _swiglu[grid](
            flex_ctx.out_data.reinterpret(out),
            flex_ctx.out_data.expected_scale,
            flex_ctx.out_data.actual_scale,
            flex_ctx.out_data.checksum_scale,
            flex_ctx.inp_data.reinterpret(a),
            flex_ctx.inp_data.scale,
            alpha,
            M,
            N // 2,
            a.shape[-1],
            1,
            out.shape[-1],
            1,
            precision_config.limit,
            n_tokens,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            EVEN_N=(N // 2) % BLOCK_N == 0,
            M_BLOCKS=M_BLOCKS,
            N_BLOCKS=N_BLOCKS,
            flexpoint_saturate_inf=flex_ctx.saturate_inf,
            num_warps=num_warps,
            **kwargs,
        )
        # restore the original leading dimensions
        out = out.view(a.shape[:-1] + out.shape[-1:])
        return out


def swiglu(a, alpha, precision_config, routing_data=None):
    return SwiGLU.apply(a, alpha, precision_config, routing_data)

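# Usage sketch for the fused entry point above (illustration only, not taken
# from the upstream Triton source; shape, dtype, alpha, and limit are assumed):
#
#     x = torch.randn(8, 256, device="cuda", dtype=torch.float16)
#     cfg = PrecisionConfig(limit=10.0)
#     y = swiglu(x, alpha=1.702, precision_config=cfg)
#     # the last dimension is halved: y.shape == (8, 128)
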

def swiglu_torch(a, alpha, precision_config):
    limit = precision_config.limit
    a_gelu = a[..., ::2]
    if limit is not None:
        a_gelu = a_gelu.clamp(max=limit)
    a_linear = a[..., 1::2]
    if limit is not None:
        a_linear = a_linear.clamp(min=-limit, max=limit)

    out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
    out = out_gelu * (a_linear + 1)
    return out

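# Reference-check sketch (illustration only, not taken from the upstream Triton
# source; shapes, alpha, limit, and tolerances are assumptions). swiglu_torch
# computes, per interleaved pair (x, y) along the last dimension,
#     out = x * sigmoid(alpha * x) * (y + 1)
# with x clamped to max=limit and y clamped to [-limit, limit] when limit is set:
#
#     a = torch.randn(4, 64, device="cuda", dtype=torch.float32)
#     cfg = PrecisionConfig(limit=10.0)
#     ref = swiglu_torch(a, 1.702, cfg)
#     fused = swiglu(a, 1.702, cfg)
#     torch.testing.assert_close(fused, ref, rtol=1e-3, atol=1e-3)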