TensorRT-LLMs/examples/apps/openai_server.py
#!/usr/bin/env python
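"""OpenAI-compatible API server built on the TensorRT-LLM high-level API (hlapi).

A single LLM instance is exposed through /health, /version, /v1/models,
/v1/completions and /v1/chat/completions.

Illustrative launch command (the model path is a placeholder):

    python openai_server.py <model_dir> --tp_size 2 --port 8000
"""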
import asyncio
import logging
from http import HTTPStatus
from pathlib import Path
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
                    TypedDict)

import click
import uvicorn
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from transformers import AutoTokenizer, PreTrainedTokenizer

# yapf: disable
from tensorrt_llm.hlapi import LLM, BuildConfig, KvCacheConfig
from tensorrt_llm.hlapi.llm import RequestOutput
from tensorrt_llm.hlapi.openai_protocol import (
    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    ChatCompletionResponse, ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
    ChatMessage, CompletionRequest, CompletionResponse,
    CompletionResponseChoice, CompletionResponseStreamChoice,
    CompletionStreamResponse, DeltaMessage, ErrorResponse, FunctionCall,
    ModelCard, ModelList, ToolCall, UsageInfo)
from tensorrt_llm.version import __version__ as VERSION
# yapf: enable

TIMEOUT_KEEP_ALIVE = 5  # seconds.

class ConversationMessage(TypedDict):
    role: str
    content: str


def parse_chat_message_content(
        message: ChatCompletionMessageParam, ) -> List[ConversationMessage]:
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    # for Iterable[ChatCompletionContentPartTextParam]
    texts: List[str] = []
    for part in content:
        part_type = part["type"]
        if part_type == "text":
            text = part["text"]
            texts.append(text)
        else:
            raise NotImplementedError(f"{part_type} is not supported")

    text_prompt = "\n".join(texts)
    return [ConversationMessage(role=role, content=text_prompt)]

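# OpenaiServer wraps an hlapi LLM instance in a FastAPI app that mirrors the
# OpenAI REST surface: /health, /version, /v1/models, /v1/completions and
# /v1/chat/completions.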
class OpenaiServer:

    def __init__(self,
                 llm: LLM,
                 model: str,
                 kv_cache_config: KvCacheConfig,
                 hf_tokenizer: Optional[PreTrainedTokenizer] = None):
        self.llm = llm
        self.kv_cache_config = kv_cache_config
        self.tokenizer = hf_tokenizer

        model_dir = Path(model)
        if model_dir.exists() and model_dir.is_dir():
            self.model = model_dir.name
        else:
            self.model = model

        self.app = FastAPI()

        @self.app.exception_handler(RequestValidationError)
        async def validation_exception_handler(_, exc):
            return self.create_error_response(message=str(exc))

        self.register_routes()

    @staticmethod
    def create_error_response(
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> JSONResponse:
        error_response = ErrorResponse(message=message,
                                       type=err_type,
                                       code=status_code.value)
        return JSONResponse(content=error_response.model_dump(),
                            status_code=error_response.code)

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/version", self.version, methods=["GET"])
        self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
        self.app.add_api_route("/v1/completions",
                               self.openai_completion,
                               methods=["POST"])
        self.app.add_api_route("/v1/chat/completions",
                               self.openai_chat,
                               methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def version(self) -> JSONResponse:
        ver = {"version": VERSION}
        return JSONResponse(content=ver)

    async def get_model(self) -> JSONResponse:
        model_list = ModelList(data=[ModelCard(id=self.model)])
        return JSONResponse(content=model_list.model_dump())

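    # Chat completions: the request messages are flattened into a single prompt
    # via the tokenizer's chat template, submitted with generate_async, and the
    # result is returned either as an SSE stream or as a single JSON response.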
    async def openai_chat(self, request: ChatCompletionRequest) -> Response:

        def get_role() -> str:
            if request.add_generation_prompt:
                role = "assistant"
            else:
                role = request.messages[-1]["role"]
            return role

        def stream_usage_info(prompt_tokens: int, completion_tokens: int):
            if request.stream_options and request.stream_options.include_usage and \
                request.stream_options.continuous_usage_stats:
                usage = UsageInfo(prompt_tokens=prompt_tokens,
                                  completion_tokens=completion_tokens,
                                  total_tokens=prompt_tokens +
                                  completion_tokens)
            else:
                usage = None
            return usage

        def create_logprobs(token_ids: List[int],
                            logprobs: List[float]) -> ChatCompletionLogProbs:
            assert len(token_ids) == len(logprobs), \
                "token_ids and logprobs have different lengths"
            content: List[ChatCompletionLogProbsContent] = []
            for token_id, logprob in zip(token_ids, logprobs):
                token = self.tokenizer.decode(token_id)
                # returning multiple logprobs is not supported
                first_logprob = ChatCompletionLogProbsContent(
                    token=token,
                    logprob=max(logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace")))
                content.append(first_logprob)
            chat_logprobs = ChatCompletionLogProbs(content=content)
            return chat_logprobs

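        # Streaming path: each partial RequestOutput is converted into an SSE
        # "data:" chunk in the ChatCompletionStreamResponse format; the first
        # chunk carries the assistant role, and usage can be attached per
        # chunk when stream_options requests it.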
        async def chat_stream_generator(
                promise: RequestOutput) -> AsyncGenerator[str, None]:
            first_iteration = True
            num_choices = 1 if request.n is None else request.n
            finish_reason_sent = [False] * num_choices
            role = get_role()

            def yield_first_chat(num_tokens: int,
                                 role: str = None,
                                 content: str = None):
                for i in range(num_choices):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(role=role, content=content),
                        logprobs=None,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        choices=[choice_data], model=self.model)
                    chunk.usage = stream_usage_info(num_tokens, 0)
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            async for res in promise:
                prompt_tokens = len(res.prompt_token_ids)
                if first_iteration:
                    # yield_first_chat is a generator function, so iterate
                    # over it to forward the initial chunks to the client
                    for first_chunk in yield_first_chat(prompt_tokens,
                                                        role=role):
                        yield first_chunk

                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for first_chunk in yield_first_chat(
                                    prompt_tokens, content=last_msg_content):
                                yield first_chunk
                first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_text = output.text_diff
                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if delta_text:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i, delta=delta_message, finish_reason=None)
                        if request.logprobs:
                            logprobs = output.logprobs_diff
                            token_ids = output.token_ids_diff
                            choice_data.logprobs = create_logprobs(
                                token_ids, logprobs)
                        chunk = ChatCompletionStreamResponse(
                            choices=[choice_data], model=self.model)
                        chunk.usage = stream_usage_info(
                            prompt_tokens, output.length)
                        data = chunk.model_dump_json()
                        yield f"data: {data}\n\n"
                    else:
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                completion_tokens = sum(output.length
                                        for output in promise.outputs)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    choices=[], model=self.model, usage=final_usage)
                final_usage_data = final_usage_chunk.model_dump_json()
                yield f"data: {final_usage_data}\n\n"

        async def create_chat_response(
                promise: RequestOutput) -> ChatCompletionResponse:
            await promise.aresult()
            choices: List[ChatCompletionResponseChoice] = []
            role = get_role()
            for output in promise.outputs:
                if request.tool_choice and isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                    message = ChatMessage(
                        role=role,
                        content="",
                        tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=output.text))
                        ])
                else:
                    message = ChatMessage(role=role, content=output.text)

                choice = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                )

                if request.logprobs:
                    choice.logprobs = create_logprobs(output.token_ids,
                                                      output.logprobs)
                choices.append(choice)

            if request.echo:
                last_msg_content = ""
                if conversation and conversation[-1].get(
                        "content") and conversation[-1].get("role") == role:
                    last_msg_content = conversation[-1]["content"]
                for choice in choices:
                    full_message = last_msg_content + choice.message.content
                    choice.message.content = full_message

            num_prompt_tokens = len(promise.prompt_token_ids)
            num_generated_tokens = sum(
                len(output.token_ids) for output in promise.outputs)
            usage = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_generated_tokens,
                total_tokens=num_prompt_tokens + num_generated_tokens,
            )
            response = ChatCompletionResponse(
                model=self.model,
                choices=choices,
                usage=usage,
            )
            return response

        try:
            conversation: List[ConversationMessage] = []
            for msg in request.messages:
                conversation.extend(parse_chat_message_content(msg))
            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]
            prompt: str = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()

            promise = self.llm.generate_async(
                inputs=prompt,
                sampling_params=sampling_params,
                streaming=request.stream,
            )
            if request.stream:
                response_generator = chat_stream_generator(promise)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_chat_response(promise)
                return JSONResponse(content=response.model_dump())
        except Exception as e:
            return self.create_error_response(str(e))

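    # Text completions: a request may carry several prompts; each one becomes
    # its own generate_async call, and merge_promises interleaves the
    # per-prompt results into a single async stream keyed by prompt index.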
    async def openai_completion(self, request: CompletionRequest) -> Response:

        def merge_promises(
            promises: List[RequestOutput]
        ) -> AsyncIterator[Tuple[int, RequestOutput]]:
            outputs = asyncio.Queue()
            finished = [False] * len(promises)

            async def producer(i: int, promise: RequestOutput):
                async for output in promise:
                    await outputs.put((i, output))
                finished[i] = True

            _tasks = [
                asyncio.create_task(producer(i, promise))
                for i, promise in enumerate(promises)
            ]

            async def consumer():
                while not all(finished) or not outputs.empty():
                    item = await outputs.get()
                    yield item
                await asyncio.gather(*_tasks)

            return consumer()

        async def create_completion_generator(
                generator: AsyncIterator[Tuple[int, RequestOutput]],
                num_choices: int):
            num_responses_per_request = 1 if request.n is None else request.n
            echoed = [False] * num_choices
            async for prompt_idx, request_output in generator:
                prompt = request_output.prompt
                for gen_idx, output in enumerate(request_output.outputs):
                    response_idx = prompt_idx * num_responses_per_request + gen_idx
                    delta_text = output.text_diff
                    if request.echo and not echoed[response_idx]:
                        delta_text = prompt + delta_text
                        echoed[response_idx] = True
                    response = CompletionStreamResponse(
                        model=self.model,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=response_idx, text=delta_text)
                        ])
                    response_json = response.model_dump_json(
                        exclude_unset=False)
                    yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        async def create_completion_response(
                generator: AsyncIterator[Tuple[int, RequestOutput]],
                num_choices: int):
            choices = [None] * num_choices
            num_responses_per_request = 1 if request.n is None else request.n
            num_prompt_tokens = num_gen_tokens = 0
            async for prompt_idx, request_output in generator:
                num_prompt_tokens += len(request_output.prompt_token_ids)
                for gen_idx, output in enumerate(request_output.outputs):
                    num_gen_tokens += len(output.token_ids)
                    output_text = output.text
                    if request.echo:
                        output_text = request_output.prompt + output_text
                    idx = prompt_idx * num_responses_per_request + gen_idx
                    choice = CompletionResponseChoice(
                        index=idx,
                        text=output_text,
                    )
                    choices[idx] = choice

            usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_gen_tokens,
                total_tokens=num_gen_tokens + num_prompt_tokens,
            )
            response = CompletionResponse(
                model=self.model,
                choices=choices,
                usage=usage_info,
            )
            return response

        try:
            if isinstance(request.prompt, str) or \
                (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
                # a single prompt, given either as a string or as token ids
                prompts = [request.prompt]
            else:
                prompts = request.prompt

            promises: List[RequestOutput] = []
            sampling_params = request.to_sampling_params()
            for prompt in prompts:
                promise = self.llm.generate_async(
                    inputs=prompt,
                    sampling_params=sampling_params,
                    streaming=request.stream,
                )
                promises.append(promise)

            generator = merge_promises(promises)
            num_choices = len(prompts) if request.n is None else len(
                prompts) * request.n
            if request.stream:
                response_generator = create_completion_generator(
                    generator, num_choices)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_completion_response(
                    generator, num_choices)
                return JSONResponse(content=response.model_dump())
        except Exception as e:
            return self.create_error_response(str(e))

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()

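# Command-line entrypoint: builds the engine through the hlapi LLM, loads the
# Hugging Face tokenizer, and serves the FastAPI app with uvicorn.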
@click.command()
@click.argument("model_dir")
@click.option("--tokenizer", type=str, default=None)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
@click.option("--max_beam_width", type=int, default=1)
@click.option("--tp_size", type=int, default=1)
@click.option("--pp_size", type=int, default=1)
def entrypoint(model_dir: str,
               tokenizer: Optional[str] = None,
               host: Optional[str] = None,
               port: int = 8000,
               max_beam_width: int = 1,
               tp_size: int = 1,
               pp_size: int = 1):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    build_config = BuildConfig(max_batch_size=10,
                               max_beam_width=max_beam_width)
    llm = LLM(model_dir,
              tokenizer,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              build_config=build_config)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer or model_dir)

    server = OpenaiServer(llm=llm,
                          model=model_dir,
                          kv_cache_config=kv_cache_config,
                          hf_tokenizer=hf_tokenizer)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
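
# A minimal client sketch (assumes the server is running on localhost:8000 and
# the `openai` Python package is installed; model name and prompt are
# placeholders):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
#   resp = client.chat.completions.create(
#       model="<model_name>",
#       messages=[{"role": "user", "content": "Hello!"}])
#   print(resp.choices[0].message.content)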