mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import random
|
|
|
|
import pytest
|
|
import tensorrt as _trt
|
|
|
|
import tensorrt_llm.plugin as _tlp
|
|
|
|
|
|
def test_load_library():
    """Test loading the TensorRT-LLM plugin library.

    Initializes the TensorRT plugins under the TensorRT-LLM namespace, loads
    the plugin shared library into the runtime's plugin registry, checks that
    every creator reported by the registry lives in the TensorRT-LLM
    namespace, and finally verifies that deregistering the library leaves
    ``plugin_creator_list`` empty.
    """
    runtime = _trt.Runtime(_trt.Logger(_trt.Logger.WARNING))
    _trt.init_libnvinfer_plugins(runtime.logger,
                                 namespace=_tlp.TRT_LLM_PLUGIN_NAMESPACE)
    registry = runtime.get_plugin_registry()
    handle = registry.load_library(_tlp.plugin_lib_path())
    try:
        creators = registry.all_creators
        # This will give all plugins statically registered in getCreators
        # (only V3 plugins for now)
        assert len(creators) > 0
        for creator in creators:
            assert creator.plugin_namespace == _tlp.TRT_LLM_PLUGIN_NAMESPACE
    finally:
        # Always release the library handle, even when a check above fails,
        # so a failing run does not leave the library registered for the
        # tests that execute afterwards in the same process.
        registry.deregister_library(handle)

    assert len(registry.plugin_creator_list) == 0
|
|
|
|
|
|
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16', 'float32'])
def test_plugin_config(dtype):
    """Check PluginConfig dtype resolution, validation and dict round-trip.

    The test builds a config from ``{'dtype': dtype}`` and then verifies, in
    order:

    * the initial raw (underscore-prefixed) and public plugin values;
    * the values after switching ``dtype`` and overriding two plugins;
    * that a set of invalid attribute assignments raise;
    * that ``to_dict`` / ``from_dict`` round-trips to an equal dict;
    * that ``to_legacy_setting`` can be called without raising.
    """
    config = _tlp.PluginConfig.from_dict({'dtype': dtype})
    assert config.dtype == dtype

    # Initial state: the raw fields are expected to hold 'auto' (or None),
    # while the public attributes are expected to report the config dtype.
    assert config._gpt_attention_plugin == 'auto'
    assert config.gpt_attention_plugin == dtype
    assert config._nccl_plugin == 'auto'
    assert config.nccl_plugin == dtype
    assert config._gemm_plugin is None
    assert config.gemm_plugin is None

    # Pick a different dtype at random from the remaining candidates.
    other_dtypes = [
        candidate for candidate in ('float16', 'bfloat16', 'float32')
        if candidate != dtype
    ]
    replacement = random.choice(other_dtypes)
    assert replacement != dtype

    config.dtype = replacement
    config.gpt_attention_plugin = dtype
    config.gemm_plugin = 'auto'

    # An explicitly set plugin keeps its own value; plugins left on 'auto'
    # are expected to follow the new config dtype.
    assert config._gpt_attention_plugin == dtype
    assert config.gpt_attention_plugin == dtype
    assert config._nccl_plugin == 'auto'
    assert config.nccl_plugin == replacement
    assert config._gemm_plugin == 'auto'
    assert config.gemm_plugin == replacement

    # Each of these assignments must be rejected, in this order.
    rejected_assignments = (
        ('dtype', None),
        ('dtype', 'auto'),
        ('dtype', 'xyz'),
        ('gpt_attention_plugin', 'abc'),
        ('nccl_plugin', 123),
        ('a_new_xxx_plugin', 'float16'),
    )
    for attr_name, bad_value in rejected_assignments:
        with pytest.raises(Exception):
            setattr(config, attr_name, bad_value)

    # Serializing and rebuilding must yield an equal dictionary.
    dumped = config.to_dict()
    rebuilt = _tlp.PluginConfig.from_dict(dumped)
    assert dumped == rebuilt.to_dict()

    config.to_legacy_setting()
|