TensorRT-LLMs/tests/unittest/others/test_plugins.py
Kaiyu Xie 2631f21089
Update (#2978)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-03-23 16:39:35 +08:00

71 lines
2.5 KiB
Python

import random
import pytest
import tensorrt as _trt
import tensorrt_llm.plugin as _tlp
def test_load_library():
    """Test loading the TensorRT-LLM plugin library.

    Loads the plugin shared library into the runtime's plugin registry, checks
    that at least one creator is reported and that every reported creator uses
    the TensorRT-LLM plugin namespace, then deregisters the library and checks
    that ``plugin_creator_list`` is empty afterwards.
    """
    runtime = _trt.Runtime(_trt.Logger(_trt.Logger.WARNING))
    _trt.init_libnvinfer_plugins(runtime.logger,
                                 namespace=_tlp.TRT_LLM_PLUGIN_NAMESPACE)
    registry = runtime.get_plugin_registry()
    handle = registry.load_library(_tlp.plugin_lib_path())
    try:
        creators = registry.all_creators
        # This will give all plugins statically registered in getCreators (only V3 plugins for now)
        assert len(creators) > 0
        for creator in creators:
            assert creator.plugin_namespace == _tlp.TRT_LLM_PLUGIN_NAMESPACE
    finally:
        # Always release the handle, even when a check above fails; otherwise
        # the library stays loaded in the registry after a failed run.
        # NOTE(review): presumably this registry is shared with other tests in
        # the same process -- confirm.
        registry.deregister_library(handle)
    assert len(registry.plugin_creator_list) == 0
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16', 'float32'])
def test_plugin_config(dtype):
    """Check ``PluginConfig`` dtype resolution, validation and round-tripping.

    The test covers, in order:

    * a config built from ``{'dtype': dtype}`` stores ``'auto'`` for the
      attention and NCCL plugins while their public attributes read back as
      ``dtype``; the GEMM plugin is ``None`` in both views;
    * after the config dtype is changed, an explicitly set plugin keeps its
      value while plugins stored as ``'auto'`` read back as the new dtype;
    * invalid dtype / plugin values and unknown plugin names raise;
    * ``to_dict`` -> ``from_dict`` -> ``to_dict`` is stable;
    * ``to_legacy_setting`` can be called (no result is asserted).
    """
    cfg = _tlp.PluginConfig.from_dict({'dtype': dtype})

    # Freshly built config: raw (underscore) value vs. resolved public value.
    assert cfg.dtype == dtype
    assert (cfg._gpt_attention_plugin, cfg.gpt_attention_plugin) == ('auto',
                                                                     dtype)
    assert (cfg._nccl_plugin, cfg.nccl_plugin) == ('auto', dtype)
    assert cfg._gemm_plugin is None
    assert cfg.gemm_plugin is None

    # Pick one of the two remaining dtypes to switch the config to.
    alternatives = [
        candidate for candidate in ('float16', 'bfloat16', 'float32')
        if candidate != dtype
    ]
    new_dtype = random.choice(alternatives)
    assert new_dtype != dtype

    cfg.dtype = new_dtype
    cfg.gpt_attention_plugin = dtype
    cfg.gemm_plugin = 'auto'

    # Explicitly pinned plugin keeps the old dtype; 'auto' ones follow the new.
    assert (cfg._gpt_attention_plugin, cfg.gpt_attention_plugin) == (dtype,
                                                                     dtype)
    assert (cfg._nccl_plugin, cfg.nccl_plugin) == ('auto', new_dtype)
    assert (cfg._gemm_plugin, cfg.gemm_plugin) == ('auto', new_dtype)

    # Each (attribute, value) pair below must be rejected on assignment.
    rejected_assignments = (
        ('dtype', None),
        ('dtype', 'auto'),
        ('dtype', 'xyz'),
        ('gpt_attention_plugin', 'abc'),
        ('nccl_plugin', 123),
        ('a_new_xxx_plugin', 'float16'),
    )
    for attr_name, bad_value in rejected_assignments:
        with pytest.raises(Exception):
            setattr(cfg, attr_name, bad_value)

    # Serialising and rebuilding must reproduce the same dictionary.
    dumped = cfg.to_dict()
    restored = _tlp.PluginConfig.from_dict(dumped)
    assert dumped == restored.to_dict()

    cfg.to_legacy_setting()