Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Commit c47ff4da43 (parent eed9c16560)
[None][feat] Remove the hard code for activation type definition in T… (#11164)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
@@ -6,7 +6,8 @@ import torch
 from tensorrt_llm._torch.modules.fused_moe.routing import (
     ROUTING_METHOD_TYPE_TO_CLASS, RoutingMethodType)
-from tensorrt_llm._torch.utils import (Fp4QuantizedTensor, fp4_utils,
+from tensorrt_llm._torch.utils import (ActType_TrtllmGen, Fp4QuantizedTensor,
+                                       fp4_utils,
                                        get_last_power_of_2_num_tokens_buckets,
                                        last_positive_power_of_2,
                                        next_positive_power_of_2)
@@ -393,7 +394,7 @@ def fp4_block_scale_moe_runner(
         routed_scaling_factor: Optional[float],
         routing_method_type: int,
         do_finalize: bool,
-        act_type: int = 0,
+        act_type: int = ActType_TrtllmGen.SwiGlu.value,
         topk_weights: Optional[torch.Tensor] = None,
         topk_ids: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
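This hunk swaps a magic-number default (act_type: int = 0) for the equivalent named enum value. Below is a minimal, self-contained sketch of the pattern; the runner name and return value are illustrative only, not the real fp4_block_scale_moe_runner signature.

from enum import IntEnum

# Hypothetical reproduction of the pattern in this hunk; the real enum
# lives in tensorrt_llm/_torch/utils.py and the real runner signature is
# the one shown in the diff above.
class ActType_TrtllmGen(IntEnum):
    SwiGlu = 0
    Relu2 = 1

def moe_runner_sketch(act_type: int = ActType_TrtllmGen.SwiGlu.value) -> str:
    # Round-trip the plain int back through the enum so dispatch reads by
    # name rather than by magic number; raises ValueError on unknown values.
    return ActType_TrtllmGen(act_type).name

assert moe_runner_sketch() == "SwiGlu"
assert moe_runner_sketch(ActType_TrtllmGen.Relu2.value) == "Relu2"

Keeping the parameter typed as int preserves the existing custom-op interface while making the default self-documenting.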
@@ -49,6 +49,13 @@ class ActivationType(IntEnum):
     Relu2 = 8
 
 
+# Keep this in sync with the ActType enum in
+# cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h
+class ActType_TrtllmGen(IntEnum):
+    SwiGlu = 0
+    Relu2 = 1
+
+
 # IMPORTANT: when adding a new activation type, please update this function.
 # And make sure it aligned with cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h::isGatedActivation function.
 def is_gated_activation(activation_type: ActivationType) -> bool:
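The new enum carries a keep-in-sync comment but nothing enforces it. A hedged sketch of a guard test one could add is below; the expected mapping is taken solely from the values visible in this diff, and the C++ header is not actually parsed.

from enum import IntEnum

class ActType_TrtllmGen(IntEnum):  # as added by this commit
    SwiGlu = 0
    Relu2 = 1

# Assumed snapshot of the ActType numbering in
# cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h,
# based only on the values shown in this diff.
EXPECTED_CPP_ACT_TYPE = {"SwiGlu": 0, "Relu2": 1}

def test_act_type_matches_cpp_header() -> None:
    # Fails loudly if someone renumbers or extends one side without the other.
    assert {m.name: m.value for m in ActType_TrtllmGen} == EXPECTED_CPP_ACT_TYPE

test_act_type_matches_cpp_header()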