diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
index 53d3ae73a8..ae1b952030 100644
--- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -6,7 +6,8 @@ import torch
 
 from tensorrt_llm._torch.modules.fused_moe.routing import (
     ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
-from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
+from tensorrt_llm._torch.utils import (ActType_TrtllmGen, Fp4QuantizedTensor,
+                                       fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
                                        last_positive_power_of_2,
                                        next_positive_power_of_2)
@@ -393,7 +394,7 @@ def fp4_block_scale_moe_runner(
        routed_scaling_factor: Optional[float],
        routing_method_type: int,
        do_finalize: bool,
-       act_type: int = 0,
+       act_type: int = ActType_TrtllmGen.SwiGlu.value,
        topk_weights: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index e182eee32c..3c243346bb 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -49,6 +49,13 @@ class ActivationType(IntEnum):
     Relu2 = 8
 
 
+# Keep this in sync with the ActType enum in
+# cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
+class ActType_TrtllmGen(IntEnum):
+    SwiGlu = 0
+    Relu2 = 1
+
+
 # IMPORTANT: when adding a new activation type, please update this function.
 # And make sure it aligned with cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h::isGatedActivation function.
 def is_gated_activation(activation_type: ActivationType) -> bool:
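
For context, a minimal sketch of how a caller would select the activation with the new enum rather than a bare integer. Only the act_type handling below reflects this diff; the surrounding call site is illustrative and not part of the change.

# Sketch only: selecting the trtllm-gen activation via ActType_TrtllmGen
# instead of a magic number (0 / 1). The other arguments to
# fp4_block_scale_moe_runner are unchanged by this diff and omitted here.
from tensorrt_llm._torch.utils import ActType_TrtllmGen

# The new default matches the previous behaviour (act_type=0):
act_type = ActType_TrtllmGen.SwiGlu.value  # == 0

# Models using a squared-ReLU MLP would pass:
act_type = ActType_TrtllmGen.Relu2.value   # == 1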