[TRTLLM-6263][feat] Enable fp8 SwiGLU to minimize host overhead (#6540)

Signed-off-by: Junyi Xu <junyix@nvidia.com>
This commit is contained in:
JunyiXu-nv 2025-08-06 10:42:19 +08:00 committed by GitHub
parent 9a01934dbf
commit 13e0214fe0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 12 deletions

View File

@ -7,22 +7,12 @@ from torch import nn
from tensorrt_llm.mapping import Mapping
from ..custom_ops import IS_FLASHINFER_AVAILABLE
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
def swiglu(x):
if IS_FLASHINFER_AVAILABLE:
# WAR for flashinfer activation since it does not support custom op properly
from ..custom_ops import flashinfer_silu_and_mul
return flashinfer_silu_and_mul(x)
else:
gate, x = x.chunk(2, dim=-1)
return F.silu(gate) * x
from .swiglu import swiglu
class GatedMLP(nn.Module):
@ -107,7 +97,12 @@ class GatedMLP(nn.Module):
def _apply_activation(self, x):
if self.activation == F.silu:
return swiglu(x)
if self.down_proj.has_fp8_qdq:
return swiglu(x,
quant_scale=self.down_proj.input_scale,
quant_type=torch.float8_e4m3fn)
else:
return swiglu(x)
elif callable(self.activation):
return self.activation(x)
elif self.activation is None:

View File

@ -0,0 +1,87 @@
from collections.abc import Mapping
from typing import Optional
import torch
import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]
@triton.jit
def scale_and_clamp(x, scale, dtype):
if dtype == tl.float8e4nv:
clamp_min = -448.0
clamp_max = 448.0
elif dtype == tl.float8e5:
clamp_min = -57344.0
clamp_max = 57344.0
elif dtype == tl.float16:
clamp_min = -65504.0
clamp_max = 65504.0
elif dtype == tl.bfloat16:
clamp_min = -3.3895313892515355e38
clamp_max = 3.3895313892515355e38
else:
tl.static_assert(False, f"Unsupported dtype: {dtype}")
return tl.clamp(x.to(tl.float32) / scale, clamp_min, clamp_max).to(dtype)
@triton.jit
def silu_and_mul_kernel(o_ptr, o_stride, o_scale_ptr, x_ptr, x_stride, d,
BLOCK_SIZE: tl.constexpr,
HAS_O_SCALE: tl.constexpr) -> None:
i = tl.program_id(axis=0).to(tl.int64)
j = tl.program_id(axis=1)
o_row_ptr = o_ptr + o_stride * i
x_row_ptr = x_ptr + x_stride * i
offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < d
a = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
b = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)
result = tl.sigmoid(a) * a * b
if HAS_O_SCALE:
o_scale = tl.load(o_scale_ptr)
result = scale_and_clamp(result, o_scale, o_ptr.dtype.element_ty)
tl.store(o_row_ptr + offsets, result, mask=mask)
def silu_and_mul(x: torch.Tensor,
scale: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
b, n = x.shape
assert n % 2 == 0
d = n // 2
o_dtype = dtype or x.dtype
o = torch.empty((b, d), dtype=o_dtype, device=x.device)
def grid(meta: Mapping[str, int]) -> tuple[int, int]:
return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
silu_and_mul_kernel[grid](
o_ptr=o,
o_stride=o.stride(0),
o_scale_ptr=scale,
x_ptr=x,
x_stride=x.stride(0),
d=d,
BLOCK_SIZE=1024,
HAS_O_SCALE=scale is not None,
)
return o
def swiglu(x, quant_scale: torch.Tensor = None, quant_type=None):
if quant_scale is not None:
assert quant_type is not None
return silu_and_mul(x, scale=quant_scale, dtype=quant_type)
return silu_and_mul(x)