Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
C++, Python, and the Python MoE layer all share the definition of ActivationType. Currently this is done through redefinition, which is fragile and can break when new activation function types are added. The shared definitions and their consumers:

tensorrt_llm/_torch/utils.py, cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h => tensorrt_llm/layers/moe.py, cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
This commit is contained in:
parent b018b2698d
commit 1d6fbbf45d
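Before the diff itself, here is a minimal Python sketch of the shared layout this change moves to: a single IntEnum whose values mirror the renumbered C++ enum in common.h, with consumers resolving activation names through it instead of through hard-coded integers. The activation_from_string helper and its alias table are illustrative only and are not part of the commit; the real name map lives in tensorrt_llm/layers/moe.py (last hunk below).

from enum import IntEnum


# Mirrors the renumbered enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h (see the hunk below).
class ActivationType(IntEnum):
    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3
    Silu = 4
    Swiglu = 5
    Geglu = 6
    SwigluBias = 7
    Relu2 = 8


# Illustrative helper (not in the commit): resolve a config string to the shared
# enum so that Python and C++ agree on the integer that crosses the boundary.
_ALIASES = {"gelu_new": ActivationType.Gelu, "swiglu_bias": ActivationType.SwigluBias}


def activation_from_string(name: str) -> ActivationType:
    try:
        return _ALIASES.get(name.lower()) or ActivationType[name.capitalize()]
    except KeyError:
        return ActivationType.InvalidType


assert int(activation_from_string("swiglu")) == 5
assert activation_from_string("unknown") is ActivationType.InvalidType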
@@ -536,8 +536,8 @@ void help()
 "- \"num_tokens\" - The total number of tokens to benchmark\n"
 "- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
 "- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
-"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
-"swiglu\n"
+"- \"act_fn\" - The activation function to use, 1 = identity, 2 = gelu, 3 = relu, 4 = silu, 5 = swiglu, 6 = "
+"geglu, 7 = swiglu_bias, 8 = relu2\n"
 "- \"tactic_id1, tactic_id2\"\n"
 "The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
 "Valid tactics are:\n"
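To make the renumbering concrete, here is a hedged example of a benchmark workload entry under the new mapping; only the keys listed in the help text above are taken from the source, and the values are illustrative. SiLU, for instance, moves from act_fn 3 to act_fn 4.

# Illustrative workload entry (values are examples, not from the commit).
workload = {
    "num_tokens": 8192,      # total number of tokens to benchmark
    "bias": 0,               # 0 = no bias, 1 = bias
    "do_final_scale": 1,     # 0 = no scale, 1 = scale
    "act_fn": 4,             # new numbering: 4 = silu (previously 3 = silu)
}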
@@ -31,6 +31,7 @@ namespace
 {
 using ElemCopyType = uint4;
 using SFCopyType = uint32_t;
+using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;

 template <typename T>
 auto constexpr bitsPerElem()
@@ -385,23 +386,43 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
     int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
     int32_t const threads = kThreadsPerBlock;

-    auto kernel_array
-        = std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                         cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
-                kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
-
-    auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
+    auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
+                              float const* global_sf, SFType* output_sf,
+                              int32_t const* tile_idx_to_mn_limit,
+                              int32_t const* num_non_exiting_tiles,
+                              int32_t const interm_size, int32_t const tile_size)
+    {
+        switch (activation_type)
+        {
+        case ActivationType::Identity:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>;
+        case ActivationType::Gelu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
+        case ActivationType::Geglu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
+        case ActivationType::Relu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>;
+        case ActivationType::Silu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
+        case ActivationType::Swiglu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
+        case ActivationType::SwigluBias:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
+                kThreadsPerBlock>;
+        case ActivationType::Relu2:
+            // Unsupported activation type
+            break;
+        }
+        TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
+        return nullptr;
+    };
+    auto kernel = get_act_kernel(activation_params.activation_type);

     cudaLaunchConfig_t config;
     config.gridDim = blocks;
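The substance of this hunk, and of the doActivation hunk further down, is the same pattern change: kernel selection moves from indexing a positional array with the raw enum value, which can pick the wrong kernel if the enum is ever reordered, to an explicit switch that fails loudly on unsupported types. A rough Python sketch of the idea, with illustrative names rather than the real CUDA symbols:

from enum import IntEnum


class ActivationType(IntEnum):  # abbreviated for the sketch
    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3


def identity_kernel(x): return x
def gelu_kernel(x): return x      # placeholder body
def relu_kernel(x): return max(x, 0)


# Fragile (old style): the list order must match the enum numbering exactly.
KERNELS_BY_POSITION = [gelu_kernel, relu_kernel, identity_kernel]  # assumed Gelu=0, Relu=1, Identity=2


def get_kernel_positional(act: ActivationType):
    # After the renumbering above, int(act) no longer lines up with the list:
    # Identity (1) silently returns relu_kernel instead of identity_kernel.
    return KERNELS_BY_POSITION[int(act)]


# Robust (new style): every supported member is mapped explicitly; anything else raises.
def get_kernel_explicit(act: ActivationType):
    table = {
        ActivationType.Identity: identity_kernel,
        ActivationType.Gelu: gelu_kernel,
        ActivationType.Relu: relu_kernel,
    }
    if act not in table:
        raise ValueError(f"Unsupported activation type: {act}")
    return table[act]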
@@ -23,15 +23,15 @@ namespace tensorrt_llm::kernels::cutlass_kernels
 // cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu::doActivationKernel().
 enum class ActivationType
 {
-    Gelu = 0,
-    Relu,
-    Silu,
-    Swiglu,
-    Geglu,
-    SwigluBias,
-    Identity,
-    Relu2,
-    InvalidType
+    InvalidType = 0,
+    Identity = 1,
+    Gelu = 2,
+    Relu = 3,
+    Silu = 4,
+    Swiglu = 5,
+    Geglu = 6,
+    SwigluBias = 7,
+    Relu2 = 8,
 };

 } // namespace tensorrt_llm::kernels::cutlass_kernels
@@ -2244,29 +2244,39 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 {
     // IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
     // common.h
-    auto fn = [&](auto block_scaling_type)
+    auto fn
+        = [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
+            bool, int64_t const*, int, int64_t, float const*, bool,
+            TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
     {
-        auto fn_list = std::array{
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                decltype(block_scaling_type)::value>, // Gelu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                decltype(block_scaling_type)::value>, // Relu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                decltype(block_scaling_type)::value>, // Silu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                decltype(block_scaling_type)::value>, // Swiglu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                decltype(block_scaling_type)::value>, // Geglu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                decltype(block_scaling_type)::value>, // SwigluBias
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                decltype(block_scaling_type)::value>, // Identity
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
-                decltype(block_scaling_type)::value> // Relu2
-
-        };
-        return fn_list[static_cast<int>(activation_type.activation_type)];
+        switch (activation_type.activation_type)
+        {
+        case ActivationType::Identity:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
+        case ActivationType::Gelu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
+        case ActivationType::Relu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Silu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Swiglu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Geglu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
+        case ActivationType::SwigluBias:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                decltype(block_scaling_type)::value>;
+        case ActivationType::Relu2:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
+        default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
+        }
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
@@ -34,15 +34,15 @@ EventType = Enum(
 # IMPORTANT: Keep the same order of activation functions in this enum and the enum in
 # cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
 class ActivationType(IntEnum):
-    Gelu = 0
-    Relu = 1
-    Silu = 2
-    Swiglu = 3
-    Geglu = 4
-    SwigluBias = 5
-    Identity = 6
-    Relu2 = 7
-    InvalidType = 8
+    InvalidType = 0
+    Identity = 1
+    Gelu = 2
+    Relu = 3
+    Silu = 4
+    Swiglu = 5
+    Geglu = 6
+    SwigluBias = 7
+    Relu2 = 8


 def set_torch_compiling(enable: bool):
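Since the Python IntEnum above and the C++ enum in common.h must stay value-for-value identical, drift can be caught with a small consistency test. The following is a hedged sketch, not part of the commit; the header path and the regex-based parsing are assumptions about the repository layout.

import re
from pathlib import Path

from tensorrt_llm._torch.utils import ActivationType

# Assumed path relative to the repository root; adjust if the layout differs.
COMMON_H = Path("cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h")


def parse_cpp_activation_enum(text: str) -> dict:
    """Extract {name: value} pairs from the body of 'enum class ActivationType'."""
    body = re.search(r"enum\s+class\s+ActivationType\s*\{(.*?)\}", text, re.S).group(1)
    return {m.group(1): int(m.group(2)) for m in re.finditer(r"(\w+)\s*=\s*(\d+)", body)}


def test_activation_type_matches_cpp():
    cpp_values = parse_cpp_activation_enum(COMMON_H.read_text())
    py_values = {member.name: int(member) for member in ActivationType}
    assert cpp_values == py_values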
@@ -20,6 +20,7 @@ import numpy as np
 import tensorrt as trt
 import torch

+from tensorrt_llm._torch.utils import ActivationType
 from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
                                  str_dtype_to_trt)
 from tensorrt_llm.layers.lora import LoraParams
@@ -49,14 +50,15 @@ from .mlp import MLP, GatedMLP

 activation_str_to_int_map = {
     # [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
-    "gelu": 0,
-    "gelu_new": 0,
-    "relu": 1,
-    "silu": 2,
-    "swiglu": 3,
-    "geglu": 4,
-    "swiglu_bias": 5,
-    "identity": 6,
+    "gelu": int(ActivationType.Gelu),
+    "gelu_new": int(ActivationType.Gelu),
+    "relu": int(ActivationType.Relu),
+    "silu": int(ActivationType.Silu),
+    "swiglu": int(ActivationType.Swiglu),
+    "geglu": int(ActivationType.Geglu),
+    "swiglu_bias": int(ActivationType.SwigluBias),
+    "identity": int(ActivationType.Identity),
+    "relu2": int(ActivationType.Relu2),
 }
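After this change the map's values come from the shared enum rather than literal integers, so a lookup like the one sketched below keeps returning the number the C++ kernels expect even if the enum is renumbered again. This is an illustrative usage sketch; the hidden_act field is only an example of where the activation name might come from.

from tensorrt_llm._torch.utils import ActivationType
from tensorrt_llm.layers.moe import activation_str_to_int_map

# Example: resolve a model config's activation name to the integer consumed by
# the C++ MoE kernels, falling back to the explicit invalid value for unknowns.
hidden_act = "swiglu"
act_fn = activation_str_to_int_map.get(hidden_act, int(ActivationType.InvalidType))
assert act_fn == int(ActivationType.Swiglu) == 5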