Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
183 lines · 6.2 KiB · Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import platform
from enum import IntEnum
from pathlib import Path

from tensorrt_llm.logger import logger

TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'


def plugin_lib_path() -> str:
    project_dir = Path(__file__).parent.parent.absolute()

    dyn_lib = "libnvinfer_plugin_tensorrt_llm.so" if platform.system(
    ) != "Windows" else "nvinfer_plugin_tensorrt_llm.dll"
    return str(project_dir.joinpath("libs", dyn_lib))


def _load_plugin_lib():
    winmode = 0 if platform.system() == "Windows" else None
    handle = ctypes.CDLL(plugin_lib_path(),
                         mode=ctypes.RTLD_GLOBAL,
                         winmode=winmode)
    try:
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as err:
        raise ImportError('TensorRT-LLM Plugin is unavailable') from err
    assert handle.initTrtLlmPlugins(None,
                                    TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))
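

# Illustrative usage sketch (comment only, not part of the original module):
# the loader is intended to be called once, before TensorRT builds or
# deserializes an engine that relies on plugins from the 'tensorrt_llm'
# namespace.
#
#   from tensorrt_llm.plugin import _load_plugin_lib  # import path assumed
#   _load_plugin_lib()  # registers the TensorRT-LLM plugin creators globally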


class ContextFMHAType(IntEnum):
    disabled = 0
    # FP16 I/O, FP16 Accumulation
    enabled = 1
    # FP16 I/O, FP32 Accumulation
    enabled_with_fp32_acc = 2
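

# Added note (not in the original source): enabled_with_fp32_acc keeps FP16
# inputs/outputs but accumulates the attention math in FP32, which is generally
# more numerically robust than pure FP16 accumulation at some cost in
# context-phase throughput.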


class PluginConfig(object):

    def __init__(self) -> None:
        self.init()

    def init(self):
        self.bert_attention_plugin = False
        self.gpt_attention_plugin = False
        self.multi_block_mode = False
        self.identity_plugin = False
        self.gemm_plugin = False
        self.smooth_quant_gemm_plugin = False
        self.layernorm_plugin = False
        self.layernorm_quantization_plugin = False
        self.rmsnorm_plugin = False
        self.rmsnorm_quantization_plugin = False
        self.attention_qk_half_accumulation = False
        self.remove_input_padding = False
        self.context_fmha_type = ContextFMHAType.disabled
        self.weight_only_groupwise_quant_matmul_plugin = False
        self.weight_only_quant_matmul_plugin = False
        self.nccl_plugin = False
        self.use_custom_all_reduce = False
        self.quantize_per_token_plugin = False
        self.quantize_tensor_plugin = False
        self.paged_kv_cache = False
        self.tokens_per_block = 0
        self.lookup_plugin = False
        self.lora_plugin = False
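
    # Illustrative configuration sketch (comment only, not in the original
    # source). The setters below return `self`, so a network-specific plugin
    # configuration can be built as a chain; dtype strings such as 'float16'
    # are stored as-is without validation here:
    #
    #   config = PluginConfig()
    #   config.set_gpt_attention_plugin('float16') \
    #         .set_gemm_plugin('float16') \
    #         .enable_remove_input_padding() \
    #         .enable_paged_kv_cache(tokens_per_block=64)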

    def enable_qk_half_accum(self):
        self.attention_qk_half_accumulation = True
        logger.info(f"Attention BMM1(QK) accumulation type is set to FP16")
        return self

    def set_context_fmha(self, context_fmha_type=ContextFMHAType.enabled):
        assert context_fmha_type in \
            [ContextFMHAType.disabled, ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc]
        self.context_fmha_type = context_fmha_type
        if context_fmha_type == ContextFMHAType.enabled:
            logger.info(f"Context FMHA Enabled")
        elif context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
            logger.info(f"Context FMHA with FP32 Accumulation Enabled")
        elif context_fmha_type == ContextFMHAType.disabled:
            logger.info(f"Context FMHA Disabled")
        return self

    def enable_remove_input_padding(self):
        self.remove_input_padding = True
        logger.info(f"Remove Padding Enabled")
        return self

    def enable_paged_kv_cache(self, tokens_per_block=64):
        self.paged_kv_cache = True
        self.tokens_per_block = tokens_per_block
        logger.info(f"Paged KV Cache Enabled")
        return self
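
    # Added note (not in the original source): `tokens_per_block` is the page
    # size of the paged KV cache, i.e. how many tokens share one cache block;
    # callers can override the default, e.g.
    #   config.enable_paged_kv_cache(tokens_per_block=128)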

    def set_gpt_attention_plugin(self, dtype='float16'):
        self.gpt_attention_plugin = dtype
        return self

    def enable_mmha_multi_block_mode(self):
        self.multi_block_mode = True
        logger.info(f"Generation Multi Block Mode Enabled")
        return self

    def set_bert_attention_plugin(self, dtype='float16'):
        self.bert_attention_plugin = dtype
        return self

    def set_identity_plugin(self, dtype='float16'):
        self.identity_plugin = dtype
        return self

    def set_gemm_plugin(self, dtype='float16'):
        self.gemm_plugin = dtype
        return self

    def set_smooth_quant_gemm_plugin(self, dtype='float16'):
        self.smooth_quant_gemm_plugin = dtype
        return self

    def set_layernorm_plugin(self, dtype='float16'):
        self.layernorm_plugin = dtype
        return self

    def set_layernorm_quantization_plugin(self, dtype='float16'):
        self.layernorm_quantization_plugin = dtype
        return self

    def set_rmsnorm_plugin(self, dtype='float16'):
        self.rmsnorm_plugin = dtype
        return self

    def set_rmsnorm_quantization_plugin(self, dtype='float16'):
        self.rmsnorm_quantization_plugin = dtype
        return self

    def set_weight_only_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_quant_matmul_plugin = dtype
        return self

    def set_weight_only_groupwise_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_groupwise_quant_matmul_plugin = dtype
        return self

    def set_nccl_plugin(self,
                        dtype='float16',
                        use_custom_all_reduce: bool = False):
        self.use_custom_all_reduce = use_custom_all_reduce
        self.nccl_plugin = dtype
        return self
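
    # Added note (not in the original source): the NCCL plugin covers the
    # multi-GPU communication ops (e.g. all-reduce for tensor parallelism);
    # `use_custom_all_reduce` opts into an alternative custom all-reduce path
    # instead of plain NCCL calls.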

    def set_quantize_per_token_plugin(self):
        self.quantize_per_token_plugin = True
        return self

    def set_quantize_tensor_plugin(self):
        self.quantize_tensor_plugin = True
        return self

    def set_lookup_plugin(self, dtype='float16'):
        self.lookup_plugin = dtype
        return self

    def set_lora_plugin(self, dtype='float16'):
        self.lora_plugin = dtype
        return self
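

# End-to-end sketch (comment only, not part of the original module). Import
# paths and the surrounding builder API are assumptions, not shown in this file.
#
#   from tensorrt_llm.plugin.plugin import (ContextFMHAType, PluginConfig,
#                                           _load_plugin_lib)
#
#   _load_plugin_lib()                      # register plugin creators with TensorRT
#   plugin_config = PluginConfig()
#   plugin_config.set_gpt_attention_plugin('float16')
#   plugin_config.set_gemm_plugin('float16')
#   plugin_config.set_context_fmha(ContextFMHAType.enabled)
#   plugin_config.enable_remove_input_padding()
#   # plugin_config is then handed to the engine-building code, which reads
#   # these flags to decide which TensorRT-LLM plugins to insert into the graph.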