TensorRT-LLMs/tensorrt_llm/plugin/plugin.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import platform
from enum import IntEnum
from pathlib import Path

from tensorrt_llm.logger import logger

TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'


def plugin_lib_path() -> str:
    # Resolve the path to the TensorRT-LLM plugin shared library shipped
    # under <package_root>/libs.
    project_dir = Path(__file__).parent.parent.absolute()
    dyn_lib = ("libnvinfer_plugin_tensorrt_llm.so"
               if platform.system() != "Windows" else
               "nvinfer_plugin_tensorrt_llm.dll")
    return str(project_dir.joinpath("libs", dyn_lib))
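# Illustrative note (not part of the original file): on a Linux installation
# the returned path typically looks like
#     <site-packages>/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so
# while on Windows it ends in nvinfer_plugin_tensorrt_llm.dll.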
def _load_plugin_lib():
    # Load the plugin shared library into the process with RTLD_GLOBAL so its
    # symbols are visible to TensorRT, then register the TensorRT-LLM plugins
    # under TRT_LLM_PLUGIN_NAMESPACE via initTrtLlmPlugins (called here with a
    # null logger pointer and the namespace string).
    winmode = 0 if platform.system() == "Windows" else None
    handle = ctypes.CDLL(plugin_lib_path(),
                         mode=ctypes.RTLD_GLOBAL,
                         winmode=winmode)

    try:
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as err:
        raise ImportError('TensorRT-LLM Plugin is unavailable') from err

    assert handle.initTrtLlmPlugins(None,
                                    TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))
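# Illustrative sketch (not part of the original file): one way to check that
# the TensorRT-LLM plugin creators are registered before building or
# deserializing an engine. The exact call site of _load_plugin_lib inside
# TensorRT-LLM may differ; trt.get_plugin_registry() is the standard TensorRT
# Python API for inspecting registered plugin creators.
#
#   import tensorrt as trt
#
#   _load_plugin_lib()  # registers creators in the 'tensorrt_llm' namespace
#   registry = trt.get_plugin_registry()
#   names = [c.name for c in registry.plugin_creator_list]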
class ContextFMHAType(IntEnum):
    disabled = 0
    # FP16 I/O, FP16 Accumulation
    enabled = 1
    # FP16 I/O, FP32 Accumulation
    enabled_with_fp32_acc = 2
class PluginConfig(object):

    def __init__(self) -> None:
        self.init()

    def init(self):
        # Plugin fields default to False (disabled); the set_*_plugin methods
        # below store the dtype string (e.g. 'float16') used to enable them.
        self.bert_attention_plugin = False
        self.gpt_attention_plugin = False
        self.multi_block_mode = False
        self.identity_plugin = False
        self.gemm_plugin = False
        self.smooth_quant_gemm_plugin = False
        self.layernorm_plugin = False
        self.layernorm_quantization_plugin = False
        self.rmsnorm_plugin = False
        self.rmsnorm_quantization_plugin = False
        self.attention_qk_half_accumulation = False
        self.remove_input_padding = False
        self.context_fmha_type = ContextFMHAType.disabled
        self.weight_only_groupwise_quant_matmul_plugin = False
        self.weight_only_quant_matmul_plugin = False
        self.nccl_plugin = False
        self.use_custom_all_reduce = False
        self.quantize_per_token_plugin = False
        self.quantize_tensor_plugin = False
        self.paged_kv_cache = False
        self.tokens_per_block = 0
        self.lookup_plugin = False
        self.lora_plugin = False
    def enable_qk_half_accum(self):
        self.attention_qk_half_accumulation = True
        logger.info("Attention BMM1(QK) accumulation type is set to FP16")
        return self

    def set_context_fmha(self, context_fmha_type=ContextFMHAType.enabled):
        assert context_fmha_type in [
            ContextFMHAType.disabled, ContextFMHAType.enabled,
            ContextFMHAType.enabled_with_fp32_acc
        ]
        self.context_fmha_type = context_fmha_type
        if context_fmha_type == ContextFMHAType.enabled:
            logger.info("Context FMHA Enabled")
        elif context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
            logger.info("Context FMHA with FP32 Accumulation Enabled")
        elif context_fmha_type == ContextFMHAType.disabled:
            logger.info("Context FMHA Disabled")
        return self

    def enable_remove_input_padding(self):
        self.remove_input_padding = True
        logger.info("Remove Padding Enabled")
        return self

    def enable_paged_kv_cache(self, tokens_per_block=64):
        self.paged_kv_cache = True
        self.tokens_per_block = tokens_per_block
        logger.info("Paged KV Cache Enabled")
        return self

    def set_gpt_attention_plugin(self, dtype='float16'):
        self.gpt_attention_plugin = dtype
        return self

    def enable_mmha_multi_block_mode(self):
        self.multi_block_mode = True
        logger.info("Generation Multi Block Mode Enabled")
        return self

    def set_bert_attention_plugin(self, dtype='float16'):
        self.bert_attention_plugin = dtype
        return self

    def set_identity_plugin(self, dtype='float16'):
        self.identity_plugin = dtype
        return self

    def set_gemm_plugin(self, dtype='float16'):
        self.gemm_plugin = dtype
        return self

    def set_smooth_quant_gemm_plugin(self, dtype='float16'):
        self.smooth_quant_gemm_plugin = dtype
        return self

    def set_layernorm_plugin(self, dtype='float16'):
        self.layernorm_plugin = dtype
        return self

    def set_layernorm_quantization_plugin(self, dtype='float16'):
        self.layernorm_quantization_plugin = dtype
        return self

    def set_rmsnorm_plugin(self, dtype='float16'):
        self.rmsnorm_plugin = dtype
        return self

    def set_rmsnorm_quantization_plugin(self, dtype='float16'):
        self.rmsnorm_quantization_plugin = dtype
        return self

    def set_weight_only_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_quant_matmul_plugin = dtype
        return self

    def set_weight_only_groupwise_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_groupwise_quant_matmul_plugin = dtype
        return self

    def set_nccl_plugin(self,
                        dtype='float16',
                        use_custom_all_reduce: bool = False):
        self.use_custom_all_reduce = use_custom_all_reduce
        self.nccl_plugin = dtype
        return self

    def set_quantize_per_token_plugin(self):
        self.quantize_per_token_plugin = True
        return self

    def set_quantize_tensor_plugin(self):
        self.quantize_tensor_plugin = True
        return self

    def set_lookup_plugin(self, dtype='float16'):
        self.lookup_plugin = dtype
        return self

    def set_lora_plugin(self, dtype='float16'):
        self.lora_plugin = dtype
        return self
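# Illustrative usage sketch (not part of the original module): the setters
# above return self, so a PluginConfig is typically built as a chain. The
# particular combination below is an assumption for illustration, not a
# recommended default.
#
#   config = PluginConfig()
#   config.set_gpt_attention_plugin(dtype='float16') \
#         .set_gemm_plugin(dtype='float16') \
#         .set_context_fmha(ContextFMHAType.enabled) \
#         .enable_remove_input_padding() \
#         .enable_paged_kv_cache(tokens_per_block=64)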