From c47ff4da43300b8c41a055df7348c788f25a2932 Mon Sep 17 00:00:00 2001
From: Guoming Zhang <137257613+nv-guomingz@users.noreply.github.com>
Date: Wed, 11 Feb 2026 21:50:45 +0800
Subject: [PATCH] =?UTF-8?q?[None][feat]=20Remove=20the=20hard=20code=20for?=
 =?UTF-8?q?=20activation=20type=20definition=20in=20T=E2=80=A6=20(#11164)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
---
 tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py | 5 +++--
 tensorrt_llm/_torch/utils.py                            | 7 +++++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
index 53d3ae73a8..ae1b952030 100644
--- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -6,7 +6,8 @@ import torch
 
 from tensorrt_llm._torch.modules.fused_moe.routing import (
     ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
-from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
+from tensorrt_llm._torch.utils import (ActType_TrtllmGen, Fp4QuantizedTensor,
+                                       fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
                                        last_positive_power_of_2,
                                        next_positive_power_of_2)
@@ -393,7 +394,7 @@ def fp4_block_scale_moe_runner(
         routed_scaling_factor: Optional[float],
         routing_method_type: int,
         do_finalize: bool,
-        act_type: int = 0,
+        act_type: int = ActType_TrtllmGen.SwiGlu.value,
         topk_weights: Optional[torch.Tensor] = None,
         topk_ids: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
 
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index e182eee32c..3c243346bb 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -49,6 +49,13 @@ class ActivationType(IntEnum):
     Relu2 = 8
 
 
+# Keep this in sync with the ActType enum in
+# cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
+class ActType_TrtllmGen(IntEnum):
+    SwiGlu = 0
+    Relu2 = 1
+
+
 # IMPORTANT: when adding a new activation type, please update this function.
 # And make sure it aligned with cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h::isGatedActivation function.
 def is_gated_activation(activation_type: ActivationType) -> bool:
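
Reviewer note (illustrative, not part of the patch): the change replaces the hard-coded
default act_type: int = 0 with ActType_TrtllmGen.SwiGlu.value, which still resolves to 0,
so existing callers keep the same behaviour while new code can name the activation
explicitly. The sketch below shows only that intent; the enum values are copied from the
patch, while the stub runner is a hypothetical stand-in and not the real
fp4_block_scale_moe_runner.

    # Minimal sketch of the default/override behaviour introduced by the patch.
    from enum import IntEnum


    class ActType_TrtllmGen(IntEnum):
        # Values copied from the patch; kept in sync with the ActType enum in
        # cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
        SwiGlu = 0
        Relu2 = 1


    def fp4_block_scale_moe_runner_stub(
            act_type: int = ActType_TrtllmGen.SwiGlu.value) -> int:
        # Hypothetical stand-in: echoes the resolved activation type so the
        # default and an explicit override are easy to compare.
        return act_type


    # The new default resolves to the old hard-coded 0 (SwiGlu) ...
    assert fp4_block_scale_moe_runner_stub() == ActType_TrtllmGen.SwiGlu
    # ... and Relu2 can be requested by name instead of a magic 1.
    assert fp4_block_scale_moe_runner_stub(ActType_TrtllmGen.Relu2.value) == 1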