# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import platform
from enum import IntEnum
from pathlib import Path

from tensorrt_llm.logger import logger

TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'


def plugin_lib_path() -> str:
    project_dir = Path(__file__).parent.parent.absolute()
    dyn_lib = "libnvinfer_plugin_tensorrt_llm.so" if platform.system(
    ) != "Windows" else "nvinfer_plugin_tensorrt_llm.dll"
    return str(project_dir.joinpath("libs", dyn_lib))


def _load_plugin_lib():
    winmode = 0 if platform.system() == "Windows" else None
    handle = ctypes.CDLL(plugin_lib_path(),
                         mode=ctypes.RTLD_GLOBAL,
                         winmode=winmode)
    try:
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as err:
        raise ImportError('TensorRT-LLM Plugin is unavailable') from err
    assert handle.initTrtLlmPlugins(None,
                                    TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))


class ContextFMHAType(IntEnum):
    disabled = 0
    # FP16 I/O, FP16 Accumulation
    enabled = 1
    # FP16 I/O, FP32 Accumulation
    enabled_with_fp32_acc = 2


class PluginConfig(object):

    def __init__(self) -> None:
        self.init()

    def init(self):
        self.bert_attention_plugin = False
        self.gpt_attention_plugin = False
        self.multi_block_mode = False
        self.identity_plugin = False
        self.gemm_plugin = False
        self.smooth_quant_gemm_plugin = False
        self.layernorm_plugin = False
        self.layernorm_quantization_plugin = False
        self.rmsnorm_plugin = False
        self.rmsnorm_quantization_plugin = False
        self.attention_qk_half_accumulation = False
        self.remove_input_padding = False
        self.context_fmha_type = ContextFMHAType.disabled
        self.weight_only_groupwise_quant_matmul_plugin = False
        self.weight_only_quant_matmul_plugin = False
        self.nccl_plugin = False
        self.use_custom_all_reduce = False
        self.quantize_per_token_plugin = False
        self.quantize_tensor_plugin = False
        self.paged_kv_cache = False
        self.tokens_per_block = 0
        self.lookup_plugin = False
        self.lora_plugin = False
        self.use_paged_context_fmha = False
        self.use_context_fmha_for_generation = False

    def enable_qk_half_accum(self):
        self.attention_qk_half_accumulation = True
        logger.info(f"Attention BMM1(QK) accumulation type is set to FP16")
        return self

    def set_context_fmha(self, context_fmha_type=ContextFMHAType.enabled):
        assert context_fmha_type in \
            [ContextFMHAType.disabled,
             ContextFMHAType.enabled,
             ContextFMHAType.enabled_with_fp32_acc]
        self.context_fmha_type = context_fmha_type
        if context_fmha_type == ContextFMHAType.enabled:
            logger.info(f"Context FMHA Enabled")
        elif context_fmha_type == ContextFMHAType.enabled_with_fp32_acc:
            logger.info(f"Context FMHA with FP32 Accumulation Enabled")
        elif context_fmha_type == ContextFMHAType.disabled:
            logger.info(f"Context FMHA Disabled")
        return self

    def enable_remove_input_padding(self):
        self.remove_input_padding = True
        logger.info(f"Remove Padding Enabled")
        return self

    def enable_paged_kv_cache(self, tokens_per_block=64):
        self.paged_kv_cache = True
        self.tokens_per_block = tokens_per_block
        logger.info(f"Paged KV Cache Enabled")
        return self

    def set_gpt_attention_plugin(self, dtype='float16'):
        self.gpt_attention_plugin = dtype
        return self

    def enable_mmha_multi_block_mode(self):
        self.multi_block_mode = True
        logger.info(f"Generation Multi Block Mode Enabled")
        return self

    def set_bert_attention_plugin(self, dtype='float16'):
        self.bert_attention_plugin = dtype
        return self

    def set_identity_plugin(self, dtype='float16'):
        self.identity_plugin = dtype
        return self

    def set_gemm_plugin(self, dtype='float16'):
        self.gemm_plugin = dtype
        return self

    def set_smooth_quant_gemm_plugin(self, dtype='float16'):
        self.smooth_quant_gemm_plugin = dtype
        return self

    def set_layernorm_plugin(self, dtype='float16'):
        self.layernorm_plugin = dtype
        return self

    def set_layernorm_quantization_plugin(self, dtype='float16'):
        self.layernorm_quantization_plugin = dtype
        return self

    def set_rmsnorm_plugin(self, dtype='float16'):
        self.rmsnorm_plugin = dtype
        return self

    def set_rmsnorm_quantization_plugin(self, dtype='float16'):
        self.rmsnorm_quantization_plugin = dtype
        return self

    def set_weight_only_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_quant_matmul_plugin = dtype
        return self

    def set_weight_only_groupwise_quant_matmul_plugin(self, dtype='float16'):
        self.weight_only_groupwise_quant_matmul_plugin = dtype
        return self

    def set_nccl_plugin(self,
                        dtype='float16',
                        use_custom_all_reduce: bool = False):
        self.use_custom_all_reduce = use_custom_all_reduce
        self.nccl_plugin = dtype
        return self

    def set_quantize_per_token_plugin(self):
        self.quantize_per_token_plugin = True
        return self

    def set_quantize_tensor_plugin(self):
        self.quantize_tensor_plugin = True
        return self

    def set_lookup_plugin(self, dtype='float16'):
        self.lookup_plugin = dtype
        return self

    def set_lora_plugin(self, dtype='float16'):
        self.lora_plugin = dtype
        return self

    def set_paged_context_fmha(self):
        self.use_paged_context_fmha = True
        return self

    def set_context_fmha_for_generation(self):
        self.use_context_fmha_for_generation = True
        return self
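

# Illustrative usage sketch: every PluginConfig setter returns `self`, so a
# build script can chain calls to enable the plugins it needs. The dtype
# strings and plugin choices below are placeholder assumptions for
# demonstration only; pick the values that match your model and build flow.
if __name__ == "__main__":
    config = (PluginConfig()
              .set_gpt_attention_plugin(dtype='float16')
              .set_gemm_plugin(dtype='float16')
              .set_context_fmha(ContextFMHAType.enabled)
              .enable_remove_input_padding()
              .enable_paged_kv_cache(tokens_per_block=64))
    # The chained calls mutate and return the same object.
    print(config.context_fmha_type, config.tokens_per_block)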