mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736) Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Add note for blackwell (#2742) Update the docs to workaround the extra-index-url issue (#2744) update README.md (#2751) Fix github io pages (#2761) Update
57 lines
1.4 KiB
Python
import os
import sys
from typing import Optional

import openai
import pytest
import requests

from openai_server import RemoteOpenAIServer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path

from tensorrt_llm.version import __version__ as VERSION
|
|
|
|
|
|
@pytest.fixture(scope="module")
def model_name():
    """Name of the model exercised by every test in this module.

    Resolved to an on-disk path by ``get_model_path`` in the ``server``
    fixture.
    """
    tiny_llama = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
    return tiny_llama
|
|
|
|
|
|
@pytest.fixture(scope="module", params=[None, 'pytorch'])
def backend(request):
    """Parametrized backend selector.

    ``None`` runs the server with its default backend; ``'pytorch'``
    selects the PyTorch backend. Each test in the module runs once per
    parameter.
    """
    selected = request.param
    return selected
|
|
|
|
|
|
@pytest.fixture(scope="module")
def server(model_name: str, backend: Optional[str]):
    """Launch a RemoteOpenAIServer for the whole test module.

    Args:
        model_name: model identifier, resolved to a local path via
            ``get_model_path``.
        backend: optional backend name; ``None`` means the server's
            default backend. (Annotation fixed from ``str``: the
            ``backend`` fixture is parametrized with ``[None, 'pytorch']``,
            so ``None`` is a legal value.)

    Yields:
        The running server; it is shut down when the module's tests finish
        (the ``with`` block closes on fixture teardown).
    """
    model_path = get_model_path(model_name)
    args = ["--max_beam_width", "4"]
    if backend is not None:
        # Only forward --backend when a specific backend was requested.
        args.extend(["--backend", backend])
    with RemoteOpenAIServer(model_path, args) as remote_server:
        yield remote_server
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    """OpenAI SDK client pointed at the module-scoped remote server."""
    openai_client = server.get_client()
    return openai_client
|
|
|
|
|
|
def test_version(server: RemoteOpenAIServer):
    """The /version endpoint must report the installed tensorrt_llm version."""
    response = requests.get(server.url_for("version"))
    assert response.status_code == 200
    payload = response.json()
    assert payload["version"] == VERSION
|
|
|
|
|
|
def test_health(server: RemoteOpenAIServer):
    """The /health endpoint must respond with HTTP 200 while the server runs."""
    response = requests.get(server.url_for("health"))
    assert response.status_code == 200
|
|
|
|
|
|
def test_model(client: openai.OpenAI, model_name: str):
    """The served model id must equal the last path component of the model name."""
    served = client.models.list().data[0]
    expected_id = model_name.split('/')[-1]
    assert served.id == expected_id
|