# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
# Synced 2026-02-06 03:01:50 +08:00
# 65 lines, 1.9 KiB, Python
from tensorrt_llm.llmapi import LlmArgs
|
|
from tensorrt_llm.llmapi.utils import (ApiStatusRegistry,
|
|
generate_api_docs_as_docstring)
|
|
|
|
|
|
def test_api_status_registry():
    """Exercise status registration and lookup on ``ApiStatusRegistry``.

    Three cases are checked: a nested function, a second function registered
    under the same name with a different status, and a method defined inside
    a nested class.
    """
    # NOTE(review): the nested names `_my_method` and `App` are kept as-is;
    # the "first status wins" check below presumably depends on both
    # definitions sharing a name -- confirm against ApiStatusRegistry.
    expected_status = "beta"

    @ApiStatusRegistry.set_api_status("beta")
    def _my_method(self, *args, **kwargs):
        pass

    status_after_first = ApiStatusRegistry.get_api_status(_my_method)
    assert status_after_first == expected_status

    @ApiStatusRegistry.set_api_status("prototype")
    def _my_method(self, *args, **kwargs):
        pass

    # The registry always keeps the first status; the behaviour is unknown if
    # one method is registered with a different status in different files.
    status_after_second = ApiStatusRegistry.get_api_status(_my_method)
    assert status_after_second == expected_status

    class App:

        @ApiStatusRegistry.set_api_status("beta")
        def _my_method(self, *args, **kwargs):
            pass

    method_status = ApiStatusRegistry.get_api_status(App._my_method)
    assert method_status == expected_status
|
|
|
|
|
|
def test_generate_api_docs_as_docstring():
    """Check that the docs generated for ``LlmArgs`` carry a beta tag.

    The generated text is printed after the check passes so it can be
    inspected in the test output.
    """
    generated_docs = generate_api_docs_as_docstring(LlmArgs)
    beta_tag = ":tag:`beta`"
    assert beta_tag in generated_docs, "the label is not generated"
    print(generated_docs)
|
|
|
|
|
|
class DelayedAssert:
    """Collect assertion outcomes and evaluate them together later.

    Each :meth:`add` call records an outcome instead of failing right away.
    :meth:`assert_all` then raises a single ``AssertionError`` if any recorded
    outcome is false, with a message that lists every recorded outcome.

    Args:
        store_stack: When true, the call stack of each :meth:`add` call is
            captured and included in the failure message.
    """

    _HEADER = 'Some assertions failed:'
    _SEPARATOR = '\n-----------------------------------------\n'

    def __init__(self, store_stack: bool = False):
        # Each entry is a (passed, message, formatted stack lines) tuple.
        self.assertions = []
        self.store_stack = store_stack

    def add(self, result: bool, msg: str):
        """Record one assertion outcome.

        Args:
            result: Outcome of the check; coerced with ``bool()``.
            msg: Description shown in the failure report; coerced with
                ``str()``.
        """
        stack = []
        if self.store_stack:
            # Formatting the stack is comparatively expensive, so it is only
            # done when the report is actually going to display it.
            import traceback
            stack = traceback.format_stack()
        self.assertions.append((bool(result), str(msg), stack))

    def get_msg(self):
        """Build the report text covering every recorded outcome.

        Returns:
            A string that starts with a single header line, followed by one
            separator-delimited section per recorded outcome. Each section
            holds the result, the message and, when ``store_stack`` is set,
            the captured stack.
        """
        sections = [self._HEADER]
        for result, msg, stack in self.assertions:
            stack_text = ''.join(stack) if self.store_stack else ''
            sections.append('\n'.join(
                [f'Assert result: {result}', msg, stack_text]))
        # The header is already the first section; it must not be prepended
        # a second time after joining.
        return self._SEPARATOR.join(sections)

    def clear(self):
        """Drop all recorded outcomes."""
        self.assertions.clear()

    def assert_all(self):
        """Raise if any recorded outcome is false, otherwise reset.

        Raises:
            AssertionError: If at least one recorded outcome is false. The
                recorded outcomes are kept in that case.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if not all(passed for passed, _, _ in self.assertions):
            raise AssertionError(self.get_msg())
        self.clear()
|