TensorRT-LLMs/examples/apps/openai_client.py
Dan Blanaru 48686bca3a
open source 7f370deb0090d885d7518c2b146399ba3933c004 (#2273)
* Update TensorRT-LLM

---------
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
2024-09-30 13:51:19 +02:00

89 lines
2.8 KiB
Python

import argparse
from openai import OpenAI
def run_completion(args: argparse.Namespace):
    """Issue one text-completion request and print what comes back.

    A client is created against ``http://localhost:8000/v1`` and a single
    completion is requested for the ``llama-v3-8b-instruct-hf`` model.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``prompt``, ``top_p``, ``temperature``, ``max_tokens``, ``echo``,
            ``n``, ``stream``, ``top_k`` and ``use_beam_search``.

    Output goes to stdout: every chunk object when ``args.stream`` is set,
    otherwise the ``text`` of every returned choice.
    """
    api = OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")

    # Fall back to a canned question when no (or an empty) prompt was given.
    text = args.prompt or "Where is New York?"

    request = {
        "model": "llama-v3-8b-instruct-hf",
        "prompt": [text],
        "top_p": args.top_p,
        "temperature": args.temperature,
        "max_tokens": args.max_tokens,
        "echo": args.echo,
        "n": args.n,
        "stream": args.stream,
        # These two options are sent through extra_body instead of as named
        # arguments of create().
        "extra_body": {
            "top_k": args.top_k,
            "use_beam_search": args.use_beam_search,
        },
    }
    response = api.completions.create(**request)

    if args.stream:
        # Streaming: the response is iterated and each chunk printed as-is.
        for piece in response:
            print(piece)
        return

    for candidate in response.choices:
        print(candidate.text)
def run_chat(args: argparse.Namespace):
    """Issue one single-message chat request and print what comes back.

    A client is created against ``http://localhost:8000/v1`` and a chat
    completion containing one ``user`` message is requested for the
    ``llama-v3-8b-instruct-hf`` model.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``prompt``, ``top_p``, ``temperature``, ``max_tokens``,
            ``stream``, ``n``, ``return_logprobs``, ``top_k``,
            ``use_beam_search`` and ``echo``.

    Output goes to stdout: every chunk object when ``args.stream`` is set,
    otherwise every returned choice object (the whole object, not only its
    text).
    """
    # TODO(pengyunl): multi-run chat example for openai api
    api = OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")

    # Fall back to a canned question when no (or an empty) prompt was given.
    text = args.prompt or "Where is New York?"
    conversation = [{"role": "user", "content": text}]

    request = {
        "model": "llama-v3-8b-instruct-hf",
        "messages": conversation,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "max_tokens": args.max_tokens,
        "stream": args.stream,
        "n": args.n,
        "logprobs": args.return_logprobs,
        # For the chat endpoint, echo is sent through extra_body together
        # with top_k and use_beam_search instead of as a named argument.
        "extra_body": {
            "top_k": args.top_k,
            "use_beam_search": args.use_beam_search,
            "echo": args.echo,
        },
    }
    response = api.chat.completions.create(**request)

    if args.stream:
        # Streaming: the response is iterated and each chunk printed as-is.
        for piece in response:
            print(piece)
        return

    for candidate in response.choices:
        print(candidate)
if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the chat or the
    # completions example depending on --api.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # (flag, add_argument keyword arguments), registered in this exact order
    # so the generated --help listing keeps its layout.
    option_specs = (
        ("--n", {"type": int, "default": 1}),
        ("--top_p", {"type": float, "default": 1.0}),
        ("--top_k", {"type": int, "default": 0}),
        ("--temperature", {"type": float, "default": 1.0}),
        ("--max_tokens", {"type": int, "default": 16}),
        ("--echo", {"action": "store_true", "default": False}),
        ("--stream", {"action": "store_true", "default": False}),
        ("--use_beam_search", {"action": "store_true", "default": False}),
        ("--api", {
            "type": str,
            "choices": ["chat", "completions"],
            "default": "chat",
        }),
        ("--prompt", {"type": str, "default": None}),
        ("--return_logprobs", {"action": "store_true", "default": False}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()

    # Anything other than "chat" goes to the completions example.
    entry_point = run_chat if args.api == "chat" else run_completion
    entry_point(args)