[#9236][feature] Make sharing of activation_type across SW layers more robust (#9238)

C++, Python, and the Python MoE layer all share the definition of ActivationType.
Currently this is done through redefinition, which is fragile and can break when adding new activation function types.

ActivationType is defined in:
tensorrt_llm/_torch/utils.py
cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
and consumed in:
tensorrt_llm/layers/moe.py
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
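A minimal sketch (illustration only, not part of the diff) of what the single shared definition buys: the Python MoE layer imports the enum instead of hard-coding integer values, so all Python consumers agree on the same integers, which are kept explicitly aligned with the C++ enum in common.h.

from tensorrt_llm._torch.utils import ActivationType  # the one shared Python definition

# These values match the explicit numbering in common.h introduced by this change.
assert int(ActivationType.Swiglu) == 5
assert int(ActivationType.Relu2) == 8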

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Neta Zmora 2025-11-20 10:06:58 +02:00 committed by GitHub
parent b018b2698d
commit 1d6fbbf45d
6 changed files with 100 additions and 67 deletions


@@ -536,8 +536,8 @@ void help()
"- \"num_tokens\" - The total number of tokens to benchmark\n"
"- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
"- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
"swiglu\n"
"- \"act_fn\" - The activation function to use, 1 = identity, 2 = gelu, 3 = relu, 4 = silu, 5 = swiglu, 6 = "
"geglu, 7 = swiglu_bias, 8 = relu2\n"
"- \"tactic_id1, tactic_id2\"\n"
"The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
"Valid tactics are:\n"


@@ -31,6 +31,7 @@ namespace
{
using ElemCopyType = uint4;
using SFCopyType = uint32_t;
using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;
template <typename T>
auto constexpr bitsPerElem()
@@ -385,23 +386,43 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
auto kernel_array
= std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
kThreadsPerBlock>,
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
float const* global_sf, SFType* output_sf,
int32_t const* tile_idx_to_mn_limit,
int32_t const* num_non_exiting_tiles,
int32_t const interm_size, int32_t const tile_size)
{
switch (activation_type)
{
case ActivationType::Identity:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>;
case ActivationType::Gelu:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
case ActivationType::Geglu:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
case ActivationType::Relu:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>;
case ActivationType::Silu:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
case ActivationType::Swiglu:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
case ActivationType::SwigluBias:
return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
kThreadsPerBlock>;
case ActivationType::Relu2:
// Unsupported activation type
break;
}
TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
return nullptr;
};
auto kernel = get_act_kernel(activation_params.activation_type);
cudaLaunchConfig_t config;
config.gridDim = blocks;

cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h

@@ -23,15 +23,15 @@ namespace tensorrt_llm::kernels::cutlass_kernels
// cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu::doActivationKernel().
enum class ActivationType
{
Gelu = 0,
Relu,
Silu,
Swiglu,
Geglu,
SwigluBias,
Identity,
Relu2,
InvalidType
InvalidType = 0,
Identity = 1,
Gelu = 2,
Relu = 3,
Silu = 4,
Swiglu = 5,
Geglu = 6,
SwigluBias = 7,
Relu2 = 8,
};
} // namespace tensorrt_llm::kernels::cutlass_kernels

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

@@ -2244,29 +2244,39 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
{
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
// common.h
auto fn = [&](auto block_scaling_type)
auto fn
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
bool, int64_t const*, int, int64_t, float const*, bool,
TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
{
auto fn_list = std::array{
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::GELU>,
decltype(block_scaling_type)::value>, // Gelu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
decltype(block_scaling_type)::value>, // Relu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
decltype(block_scaling_type)::value>, // Silu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::SiLu>,
decltype(block_scaling_type)::value>, // Swiglu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::GELU>,
decltype(block_scaling_type)::value>, // Geglu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>, // SwigluBias
&doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>,
decltype(block_scaling_type)::value>, // Identity
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
decltype(block_scaling_type)::value> // Relu2
};
return fn_list[static_cast<int>(activation_type.activation_type)];
switch (activation_type.activation_type)
{
case ActivationType::Identity:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
case ActivationType::Gelu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::Relu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
case ActivationType::Silu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Swiglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Geglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::SwigluBias:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>;
case ActivationType::Relu2:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
}
};
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};

tensorrt_llm/_torch/utils.py

@@ -34,15 +34,15 @@ EventType = Enum(
# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
Gelu = 0
Relu = 1
Silu = 2
Swiglu = 3
Geglu = 4
SwigluBias = 5
Identity = 6
Relu2 = 7
InvalidType = 8
InvalidType = 0
Identity = 1
Gelu = 2
Relu = 3
Silu = 4
Swiglu = 5
Geglu = 6
SwigluBias = 7
Relu2 = 8
def set_torch_compiling(enable: bool):
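One practical consequence of making InvalidType the zero value (illustrative sketch, not code from this commit): a zero-initialized or never-set activation id no longer silently aliases Gelu and can be rejected explicitly.

from tensorrt_llm._torch.utils import ActivationType

def resolve_activation(act_id: int) -> ActivationType:
    act = ActivationType(act_id)           # raises ValueError for out-of-range ids
    if act == ActivationType.InvalidType:  # 0 is now the "unset" sentinel
        raise ValueError("activation type was never set")
    return act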

tensorrt_llm/layers/moe.py

@@ -20,6 +20,7 @@ import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm._torch.utils import ActivationType
from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
str_dtype_to_trt)
from tensorrt_llm.layers.lora import LoraParams
@@ -49,14 +50,15 @@ from .mlp import MLP, GatedMLP
activation_str_to_int_map = {
# [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
"gelu": 0,
"gelu_new": 0,
"relu": 1,
"silu": 2,
"swiglu": 3,
"geglu": 4,
"swiglu_bias": 5,
"identity": 6,
"gelu": int(ActivationType.Gelu),
"gelu_new": int(ActivationType.Gelu),
"relu": int(ActivationType.Relu),
"silu": int(ActivationType.Silu),
"swiglu": int(ActivationType.Swiglu),
"geglu": int(ActivationType.Geglu),
"swiglu_bias": int(ActivationType.SwigluBias),
"identity": int(ActivationType.Identity),
"relu2": int(ActivationType.Relu2),
}
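With the map now derived from the shared enum, the string lookups in moe.py stay in sync with the kernel-side numbering automatically; for example (illustration only, evaluated inside moe.py where both names are already in scope):

assert activation_str_to_int_map["swiglu"] == int(ActivationType.Swiglu)           # 5
assert activation_str_to_int_map["gelu_new"] == activation_str_to_int_map["gelu"]  # both map to Gelu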