TensorRT-LLMs/tensorrt_llm/bench/dataclasses/enums.py
from __future__ import annotations

from typing import List

from aenum import MultiValueEnum

from tensorrt_llm.bindings.executor import CapacitySchedulerPolicy
from tensorrt_llm.quantization.mode import QuantAlgo

NO_EVICT = "Guaranteed No Evict"
MAX_UTIL = "Max Utilization"


class ModelArchitecture(MultiValueEnum):
    LLAMA = "LlamaForCausalLM"
    GPTJ = "GPTJForCausalLM"
    GEMMA = "GemmaForCausalLM"
    BLOOM = "BloomForCausalLM"
    OPT = "OPTForCausalLM"
    MIXTRAL = "MixtralForCausalLM"
    FALCON = "FalconForCausalLM"

class ResultsSchedulingPolicy(MultiValueEnum):
    MAX_UTILIZATION = MAX_UTIL, CapacitySchedulerPolicy.MAX_UTILIZATION
    NO_EVICT = NO_EVICT, CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
    STATIC = "Static"


class IFBSchedulingPolicy(MultiValueEnum):
    MAX_UTILIZATION = CapacitySchedulerPolicy.MAX_UTILIZATION, MAX_UTIL, "max_utilization"
    NO_EVICT = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, NO_EVICT, "guaranteed_no_evict"
    STATIC = "Static", "static"

class KVCacheDtypeEnum(MultiValueEnum):
    """Enumeration of KV Cache precisions in TRT-LLM."""

    FP8 = "FP8", "fp8", "float8"
    FP16 = None, "FP16", "fp16", "float16"
    INT8 = "INT8", "int8"
    def get_build_options(self, dtype: str) -> List[str]:
        """Get the build options for TRT-LLM based on KV cache precision.

        Args:
            dtype (str): The activation dtype for the model. This parameter
                maps the activation dtype to the GEMM plugin for certain KV
                cache precisions.

        Returns:
            List[str]: A list of command line arguments to be added to build
                commands.
        """
        # Non-FP8 KV caches pass the activation dtype through to the GEMM plugin.
        if self is not KVCacheDtypeEnum.FP8:
            return ["--gemm_plugin", dtype]
        # Assumption: the FP8 KV cache case adds no GEMM plugin flag here.
        return []

class ComputeDtypeEnum(MultiValueEnum):
    """Enumeration for activation data type."""

    # FLOAT32 = "float32", "fp32", "FP32"
    FLOAT16 = "float16", "FLOAT16", "fp16", "FP16"
    BFLOAT16 = "bfloat16", "BFLOAT16", "bf16", "bfp16", "BF16"

# TODO: use quantization.mode.QuantAlgo eventually
class QuantizationAlgo(MultiValueEnum):
    """Enumerated type for quantization algorithms for string mapping."""

    W8A16 = QuantAlgo.W8A16.value
    W4A16 = QuantAlgo.W4A16.value
    W4A16_AWQ = QuantAlgo.W4A16_AWQ.value
    W4A8_AWQ = QuantAlgo.W4A8_AWQ.value
    W4A16_GPTQ = QuantAlgo.W4A16_GPTQ.value
    FP8 = QuantAlgo.FP8.value
    INT8 = QuantAlgo.INT8.value
    W8A8_SQ_PER_CHANNEL = QuantAlgo.W8A8_SQ_PER_CHANNEL.value
    W8A8_SQ_PER_TENSOR_PLUGIN = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN.value
    W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN.value
    W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN.value
    W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN.value
    NONE = None, "None", "FP16", "BF16"
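
# Illustrative sketch, not part of the original module: a quick self-check that
# exercises the alias behaviour of these enums end to end. Running the module
# directly executes it; importing the module is unaffected.
if __name__ == "__main__":
    _demo_scheduling_policy_aliases()
    _demo_kv_cache_build_options()
    # Mixed-case aliases normalize to a single compute dtype member.
    assert ComputeDtypeEnum("bf16") is ComputeDtypeEnum.BFLOAT16
    # Quantization strings reuse the values defined by QuantAlgo.
    assert QuantizationAlgo(QuantAlgo.FP8.value) is QuantizationAlgo.FP8
    # Unquantized configurations ("None", "FP16", "BF16") collapse to NONE.
    assert QuantizationAlgo("FP16") is QuantizationAlgo.NONE
    print("enum alias checks passed")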