"""Test openai api with two clients"""
import asyncio
import os
import re
import string
import sys
from tempfile import TemporaryDirectory
import openai
import pytest
from openai_server import RemoteOpenAIServer
from tensorrt_llm.llmapi import BuildConfig
from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from utils.util import (similar, skip_gpu_memory_less_than_40gb, skip_pre_ada,
skip_single_gpu)
try:
from .test_llm import cnn_dailymail_path
except ImportError:
from test_llm import cnn_dailymail_path


@pytest.fixture(scope="module")
def model_name():
    return "llama-3.1-model/Meta-Llama-3.1-8B"


@pytest.fixture(scope="module")
def engine_from_fp8_quantization(model_name):
    """Build an FP8-quantized engine and yield its directory path."""
    tp_size = 2
    model_path = get_model_path(model_name)

    build_config = BuildConfig()
    build_config.max_batch_size = 128
    build_config.max_seq_len = 135168
    build_config.max_num_tokens = 20480
    build_config.opt_num_tokens = 128
    build_config.max_input_len = 131072
    build_config.plugin_config.context_fmha = True
    build_config.plugin_config.paged_kv_cache = True
    build_config.plugin_config._use_paged_context_fmha = True

    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    llm = LLM(model_path,
              tensor_parallel_size=tp_size,
              auto_parallel_world_size=tp_size,
              quant_config=quant_config,
              calib_config=calib_config,
              build_config=build_config)

    engine_dir = TemporaryDirectory(suffix="-engine_dir")
    llm.save(engine_dir.name)
    del llm
    yield engine_dir.name
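

# Sizing note: max_seq_len (135168) appears to be chosen as max_input_len
# (131072) plus the max_tokens=4096 requested by the long-sequence prompts in
# the test below, so a full-length input can still generate to completion.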


@pytest.fixture(scope="module", params=[None, 'pytorch'])
def backend(request):
    return request.param


@pytest.fixture(scope="module")
def server(model_name: str, engine_from_fp8_quantization: str, backend: str):
    model_path = get_model_path(model_name)
    args = ["--tp_size", "2", "--tokenizer", model_path]
    if backend is not None:
        args.append("--backend")
        args.append(backend)
    with RemoteOpenAIServer(engine_from_fp8_quantization,
                            args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    return server.get_client()


@pytest.fixture(scope="module")
def async_client(server: RemoteOpenAIServer):
    return server.get_async_client()


def generate_payload(size):
    """Generate a random string of ASCII letters from `size` random bytes."""
    random_bytes = os.urandom(size)
    # Keep only alphabetic characters and join them into a single string.
    payload = ''.join(
        filter(lambda x: x in string.ascii_letters,
               random_bytes.decode('latin1')))
    return payload
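

# Illustrative expectation: only the 52 ASCII letters out of 256 possible byte
# values survive the filter, so the payload is purely alphabetic and, in
# practice, roughly size * 52/256 characters long, e.g.:
#
#   payload = generate_payload(1024)
#   assert payload.isalpha() and len(payload) < 1024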


@skip_single_gpu
@skip_pre_ada
@skip_gpu_memory_less_than_40gb
@pytest.mark.asyncio(loop_scope="module")
async def test_multi_chat_session(client: openai.OpenAI,
                                  async_client: openai.AsyncOpenAI,
                                  model_name: str):
    """
    RCCA: https://nvbugs/4972030
    """

    async def send_request(prompt):
        try:
            completion = await async_client.completions.create(
                model=model_name,
                prompt=prompt,
                max_tokens=4096,
                temperature=0.0,
            )
            print(completion.choices[0].text)
        except Exception as e:
            print(f"Error: {e}")

    # Send an async long-sequence request every 3s; these tasks keep running
    # in the background while the sync requests below are issued.
    tasks = []
    for _ in range(30):
        prompt = "Tell me a long story with random letters " \
            + generate_payload(50000)
        tasks.append(asyncio.create_task(send_request(prompt)))
        await asyncio.sleep(3)

    # Send sync requests with a short sequence.
    outputs = []
    for _ in range(200):
        prompt = "Tell me a story about Pernekhan living in the year 3000!"
        completion = client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=50,
            temperature=0.0,
        )
        answer = completion.choices[0].text
        outputs.append(answer)
        # The result should not contain runs of three or more special
        # characters (a sign of corrupted output).
        pattern = re.compile(r'[^a-zA-Z0-9\s\'\"]{3,}')
        assert not bool(pattern.search(answer)), answer
        # The result should stay consistent with the first answer.
        assert similar(outputs[0], answer, threshold=0.2)
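

# Optional convenience entry point: run this file's tests directly with
# pytest. Assumes pytest is installed and the model/engine assets referenced
# by the fixtures are available locally.
if __name__ == "__main__":
    pytest.main([__file__])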