[None][feat] Remove the hard code for activation type definition in T… (#11164)

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
Guoming Zhang 2026-02-11 21:50:45 +08:00 committed by GitHub
parent eed9c16560
commit c47ff4da43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 2 deletions

View File

@ -6,7 +6,8 @@ import torch
from tensorrt_llm._torch.modules.fused_moe.routing import (
ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
from tensorrt_llm._torch.utils import (ActType_TrtllmGen, Fp4QuantizedTensor,
fp4_utils,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
next_positive_power_of_2)
@ -393,7 +394,7 @@ def fp4_block_scale_moe_runner(
routed_scaling_factor: Optional[float],
routing_method_type: int,
do_finalize: bool,
act_type: int = 0,
act_type: int = ActType_TrtllmGen.SwiGlu.value,
topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None) -> List[torch.Tensor]:

View File

@ -49,6 +49,13 @@ class ActivationType(IntEnum):
Relu2 = 8
# Keep this in sync with the ActType enum in
# cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
class ActType_TrtllmGen(IntEnum):
    """Activation types understood by the TRT-LLM Gen batched-GEMM kernel runner.

    Mirrors the C++ ``ActType`` enum referenced above; the raw integer
    ``.value`` is what crosses the Python/C++ boundary (e.g. as the
    ``act_type`` argument of ``fp4_block_scale_moe_runner``), so the
    numeric values here must match the C++ definition exactly.
    """
    # SwiGlu is the default activation for the fp4 block-scale MoE runner.
    SwiGlu = 0
    Relu2 = 1
# IMPORTANT: when adding a new activation type, please update this function.
# Make sure it is aligned with cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h::isGatedActivation.
def is_gated_activation(activation_type: ActivationType) -> bool: