"""Test openai api with two clients"""
import asyncio
import os
import re
import string
import sys
from tempfile import TemporaryDirectory

import openai
import pytest
from openai_server import RemoteOpenAIServer

from tensorrt_llm.llmapi import BuildConfig
from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from utils.util import (similar, skip_gpu_memory_less_than_40gb, skip_pre_ada,
                        skip_single_gpu)

try:
    from .test_llm import cnn_dailymail_path
except ImportError:
    from test_llm import cnn_dailymail_path


@pytest.fixture(scope="module")
def model_name():
    return "llama-3.1-model/Meta-Llama-3.1-8B"


@pytest.fixture(scope="module")
def engine_from_fp8_quantization(model_name):
    """Build an FP8-quantized TP-2 engine and yield its directory path."""
    tp_size = 2
    model_path = get_model_path(model_name)
    build_config = BuildConfig()
    build_config.max_batch_size = 128
    build_config.max_seq_len = 135168
    build_config.max_num_tokens = 20480
    build_config.opt_num_tokens = 128
    build_config.max_input_len = 131072
    build_config.plugin_config.context_fmha = True
    build_config.plugin_config.paged_kv_cache = True
    build_config.plugin_config._use_paged_context_fmha = True
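    # Sizing note: max_seq_len (135168) = max_input_len (131072) + the 4096
    # max_tokens requested by the async client in the test below.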
    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    calib_config = CalibConfig(calib_dataset=cnn_dailymail_path)

    llm = LLM(model_path,
              tensor_parallel_size=tp_size,
              auto_parallel_world_size=tp_size,
              quant_config=quant_config,
              calib_config=calib_config,
              build_config=build_config)

    engine_dir = TemporaryDirectory(suffix="-engine_dir")
    llm.save(engine_dir.name)
    # Free the in-process LLM so the GPUs are available to the server fixture.
    del llm
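
    # engine_dir must stay referenced by this suspended generator frame so the
    # temporary directory survives until fixture teardown removes it.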
    yield engine_dir.name


@pytest.fixture(scope="module")
def server(model_name: str, engine_from_fp8_quantization: str):
    model_path = get_model_path(model_name)
    args = ["--tp_size", "2", "--tokenizer", model_path]
    with RemoteOpenAIServer(engine_from_fp8_quantization,
                            args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    return server.get_client()


@pytest.fixture(scope="module")
def async_client(server: RemoteOpenAIServer):
    return server.get_async_client()


def generate_payload(size):
    """Generate a random ASCII-letter string from `size` random bytes."""
    random_bytes = os.urandom(size)

    # Keep only ASCII letters and join them into a single string.
    payload = ''.join(
        filter(lambda x: x in string.ascii_letters,
               random_bytes.decode('latin1')))

    return payload
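
# Only 52 of 256 byte values are ASCII letters, so generate_payload(n) returns
# roughly n * 52 / 256 ≈ 0.2 * n characters (about 10k letters for the 50000
# random bytes requested in the test below).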


@skip_single_gpu
@skip_pre_ada
@skip_gpu_memory_less_than_40gb
@pytest.mark.asyncio(loop_scope="module")
async def test_multi_chat_session(client: openai.OpenAI,
                                  async_client: openai.AsyncOpenAI,
                                  model_name: str):
    """
    Long async completions must not corrupt concurrent short sync completions.

    RCCA: https://nvbugs/4972030
    """

    async def send_request(prompt):
        try:
            completion = await async_client.completions.create(
                model=model_name,
                prompt=prompt,
                max_tokens=4096,
                temperature=0.0,
            )
            print(completion.choices[0].text)
        except Exception as e:
            print(f"Error: {e}")

    # Send an async request every 3s with a long sequence.
    tasks = []
    for _ in range(30):
        prompt = "Tell me a long story with random letters " \
            + generate_payload(50000)
        tasks.append(asyncio.create_task(send_request(prompt)))
        await asyncio.sleep(3)

    # Send sync requests with a short sequence.
    outputs = []
    for _ in range(200):
        prompt = "Tell me a story about Pernekhan living in the year 3000!"
        completion = client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=50,
            temperature=0.0,
        )
        answer = completion.choices[0].text
        outputs.append(answer)

    # The result should not include runs of special characters.
    pattern = re.compile(r'[^a-zA-Z0-9\s\'\"]{3,}')
    assert not bool(pattern.search(answer)), answer
    # The result should be consistent between the first and last answers.
    assert similar(outputs[0], answer, threshold=0.2)
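
    # The long-sequence tasks are otherwise never awaited; drain them so none
    # are left pending when the event loop closes.
    await asyncio.gather(*tasks)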