TensorRT-LLMs/tests/tools/plugin_gen/test_plugin_gen.py
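# End-to-end test for tensorrt_llm.tools.plugin_gen: generates TensorRT
# plugins from the FMHA Triton kernel metadata. It requires a matching Triton
# installation and the Triton compile binary, and is skipped otherwise and
# when running under TRT automation.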
import os
from importlib.metadata import version

import pytest

from .kernel_config import get_fmha_kernel_meta_data

KERNEL_META_DATA = get_fmha_kernel_meta_data()

try:
    from tensorrt_llm.tools.plugin_gen.plugin_gen import (TRITON_COMPILE_BIN,
                                                          gen_trt_plugins)
except ImportError:
    # Fall back to stubs so test collection still succeeds when the
    # plugin_gen tooling is unavailable; the test below is skipped in that case.
    TRITON_COMPILE_BIN = "does_not_exist"

    def gen_trt_plugins(*args, **kwargs):
        pass

WORKSPACE = './tmp/'


def is_triton_installed() -> bool:
    # the triton detection does not work in the PyTorch NGC 23.10 container
    try:
        if version('triton') != "2.1.0+440fd1b":
            return False
    except Exception:
        return False
    return os.path.exists(TRITON_COMPILE_BIN)


def is_trt_automation() -> bool:
    # TRT automation environments are identified by the presence of this file.
    return os.path.exists("/build/config.yml")


@pytest.mark.skipif(
    not is_triton_installed() or is_trt_automation(),
    reason=
    'triton is not installed, or the test is running in TRT automation')
def test_end_to_end():
    gen_trt_plugins(workspace=WORKSPACE, metas=[KERNEL_META_DATA])
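

# A minimal sketch of invoking this test directly, assuming pytest and the
# plugin_gen tooling are installed. This entry point is an illustrative
# addition, not part of the original file.
if __name__ == '__main__':
    import sys
    sys.exit(pytest.main([__file__, '-v']))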