TensorRT-LLM/tensorrt_llm/serve/openai_server.py
#!/usr/bin/env python
import asyncio
import signal
import traceback
from contextlib import asynccontextmanager
from http import HTTPStatus
from pathlib import Path
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
                    TypedDict)

import uvicorn
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai.types.chat import ChatCompletionMessageParam

# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.llmapi import LLM
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.utils import nvtx_mark
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                ChatCompletionResponse,
                                                CompletionRequest,
                                                CompletionResponse,
                                                CompletionResponseChoice,
                                                ErrorResponse, ModelCard,
                                                ModelList, UsageInfo,
                                                to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
    ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
    chat_stream_post_processor, completion_response_post_processor,
    completion_stream_post_processor)
from tensorrt_llm.version import __version__ as VERSION
# yapf: enable

TIMEOUT_KEEP_ALIVE = 5  # seconds.
class ConversationMessage(TypedDict):
    role: str
    content: str


def parse_chat_message_content(
        message: ChatCompletionMessageParam) -> List[ConversationMessage]:
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    # for Iterable[ChatCompletionContentPartTextParam]
    texts: List[str] = []
    for part in content:
        part_type = part["type"]
        if part_type == "text":
            text = part["text"]
            texts.append(text)
        else:
            raise NotImplementedError(f"{part_type} is not supported")

    text_prompt = "\n".join(texts)
    return [ConversationMessage(role=role, content=text_prompt)]
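
# A minimal illustration (hypothetical inputs) of how the parser above
# flattens OpenAI-style message content; a plain string and a list of text
# parts both collapse into a single ConversationMessage:
#
#   parse_chat_message_content({"role": "user", "content": "hi"})
#   # -> [{"role": "user", "content": "hi"}]
#
#   parse_chat_message_content(
#       {"role": "user", "content": [{"type": "text", "text": "a"},
#                                    {"type": "text", "text": "b"}]})
#   # -> [{"role": "user", "content": "a\nb"}]
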
class OpenAIServer:

    def __init__(self,
                 llm: LLM,
                 model: str):
        self.llm = llm
        self.tokenizer = llm.tokenizer

        model_dir = Path(model)
        if model_dir.exists() and model_dir.is_dir():
            self.model = model_dir.name
        else:
            self.model = model

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # terminate rank0 worker
            yield
            self.llm.shutdown()

        self.app = FastAPI(lifespan=lifespan)

        @self.app.exception_handler(RequestValidationError)
        async def validation_exception_handler(_, exc):
            return self.create_error_response(message=str(exc))

        self.register_routes()

    async def await_disconnected(self, raw_request: Request, promise):
        while not await raw_request.is_disconnected():
            await asyncio.sleep(1)
        if not promise.finished:
            promise.abort()
            logger.info(
                f"{raw_request.client} is disconnected, abort {promise.request_id}"
            )

    @property
    def postproc_worker_enabled(self) -> bool:
        return self.llm.args.num_postprocess_workers > 0
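
    # Note: when `num_postprocess_workers` > 0, detokenization and response
    # formatting are offloaded to dedicated postprocessing worker processes,
    # and each RequestOutput arrives with its `_postprocess_result` already
    # populated; otherwise the post-processor runs inline in this server
    # (see the `postproc_worker_enabled` branches in the handlers below).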

    @staticmethod
    def create_error_response(
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> JSONResponse:
        error_response = ErrorResponse(message=message,
                                       type=err_type,
                                       code=status_code.value)
        return JSONResponse(content=error_response.model_dump(),
                            status_code=error_response.code)

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/version", self.version, methods=["GET"])
        self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
        # TODO: the metrics endpoint only reports iteration stats,
        # not the runtime stats, for now
        self.app.add_api_route("/metrics",
                               self.get_iteration_stats,
                               methods=["GET"])
        self.app.add_api_route("/v1/completions",
                               self.openai_completion,
                               methods=["POST"])
        self.app.add_api_route("/v1/chat/completions",
                               self.openai_chat,
                               methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def version(self) -> JSONResponse:
        ver = {"version": VERSION}
        return JSONResponse(content=ver)

    async def get_model(self) -> JSONResponse:
        model_list = ModelList(data=[ModelCard(id=self.model)])
        return JSONResponse(content=model_list.model_dump())

    async def get_iteration_stats(self) -> JSONResponse:
        stats = []
        async for stat in self.llm.get_stats_async(2):
            stats.append(stat)
        return JSONResponse(content=stats)
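
    # The GET routes above can be exercised with any HTTP client once the
    # server is running; an illustrative sketch (host and port are
    # placeholders for whatever was passed to `__call__`):
    #
    #   import requests
    #   requests.get("http://localhost:8000/health").status_code   # 200
    #   requests.get("http://localhost:8000/v1/models").json()
    #   requests.get("http://localhost:8000/metrics").json()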

    async def openai_chat(self, request: ChatCompletionRequest,
                          raw_request: Request) -> Response:

        def get_role() -> str:
            if request.add_generation_prompt:
                role = "assistant"
            else:
                role = request.messages[-1]["role"]
            return role

        async def chat_stream_generator(
                promise: RequestOutput,
                postproc_params: PostprocParams) -> AsyncGenerator[str, None]:
            if not self.postproc_worker_enabled:
                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
            async for res in promise:
                pp_results = res.outputs[
                    0]._postprocess_result if self.postproc_worker_enabled else post_processor(
                        res, args)
                for pp_res in pp_results:
                    yield pp_res
            yield "data: [DONE]\n\n"
            nvtx_mark("generation ends")

        async def create_chat_response(
                promise: RequestOutput,
                postproc_params: PostprocParams) -> ChatCompletionResponse:
            await promise.aresult()
            if self.postproc_worker_enabled:
                return promise.outputs[0]._postprocess_result
            else:
                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
                return post_processor(promise, args)

        try:
            conversation: List[ConversationMessage] = []
            for msg in request.messages:
                conversation.extend(parse_chat_message_content(msg))
            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]
            prompt: str = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()
            disaggregated_params = to_llm_disaggregated_params(
                request.disaggregated_params)

            postproc_args = ChatPostprocArgs.from_request(request)
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == get_role():
                postproc_args.last_message_content = conversation[-1]["content"]
            postproc_params = PostprocParams(
                post_processor=chat_stream_post_processor
                if request.stream else chat_response_post_processor,
                postproc_args=postproc_args,
            )

            promise = self.llm.generate_async(
                inputs=prompt,
                sampling_params=sampling_params,
                _postproc_params=postproc_params
                if self.postproc_worker_enabled else None,
                streaming=request.stream,
                disaggregated_params=disaggregated_params,
            )
            asyncio.create_task(self.await_disconnected(raw_request, promise))
            if not self.postproc_worker_enabled:
                postproc_args.tokenizer = self.tokenizer
                postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)

            if request.stream:
                response_generator = chat_stream_generator(
                    promise, postproc_params)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_chat_response(promise, postproc_params)
                return JSONResponse(content=response.model_dump())
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            return self.create_error_response(str(e))
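
    # Streaming responses from both endpoints follow the OpenAI server-sent
    # events convention: each chunk yielded by the post-processor is expected
    # to already be a `data: <json>\n\n` SSE line, and the stream is closed
    # with the literal `data: [DONE]\n\n` sentinel emitted by the generators
    # above and below.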

    async def openai_completion(self, request: CompletionRequest,
                                raw_request: Request) -> Response:

        def merge_promises(
            promises: List[RequestOutput],
            postproc_params_collections: List[Optional[PostprocParams]]
        ) -> AsyncIterator[Tuple[RequestOutput, Optional[PostprocParams]]]:
            outputs = asyncio.Queue()
            finished = [False] * len(promises)

            async def producer(i: int, promise: RequestOutput,
                               postproc_params: Optional[PostprocParams]):
                async for output in promise:
                    await outputs.put((output, postproc_params))
                finished[i] = True

            _tasks = [
                asyncio.create_task(producer(i, promise, postproc_params))
                for i, (promise, postproc_params) in enumerate(
                    zip(promises, postproc_params_collections))
            ]

            async def consumer():
                while not all(finished) or not outputs.empty():
                    item = await outputs.get()
                    yield item
                await asyncio.gather(*_tasks)

            return consumer()
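
        # merge_promises implements a simple fan-in: one producer task per
        # prompt pushes outputs onto a shared queue, and the consumer drains
        # the queue until every producer has finished, so chunks from
        # different prompts interleave in arrival order rather than prompt
        # order.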

        async def create_completion_generator(
                generator: AsyncIterator[Tuple[RequestOutput,
                                               Optional[PostprocParams]]]):
            async for request_output, postproc_params in generator:
                if not self.postproc_worker_enabled:
                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
                    pp_result = post_processor(request_output, args)
                else:
                    pp_result = request_output.outputs[0]._postprocess_result
                for pp_res in pp_result:
                    yield pp_res
            yield "data: [DONE]\n\n"

        async def create_completion_response(
            generator: AsyncIterator[Tuple[RequestOutput,
                                           Optional[PostprocParams]]]
        ) -> CompletionResponse:
            all_choices: List[CompletionResponseChoice] = []
            num_prompt_tokens = num_gen_tokens = 0
            async for request_output, postproc_params in generator:
                pp_result: CompletionResponse
                if not self.postproc_worker_enabled:
                    post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
                    pp_result = post_processor(request_output, args)
                else:
                    pp_result = request_output.outputs[0]._postprocess_result

                choices, usage = pp_result.choices, pp_result.usage
                all_choices.extend(choices)
                num_prompt_tokens += usage.prompt_tokens
                num_gen_tokens += usage.completion_tokens

            usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_gen_tokens,
                total_tokens=num_gen_tokens + num_prompt_tokens,
            )
            response = CompletionResponse(
                model=self.model,
                choices=all_choices,
                usage=usage_info,
            )
            return response

        try:
            if isinstance(request.prompt, str) or \
                    (isinstance(request.prompt, list)
                     and isinstance(request.prompt[0], int)):
                prompts = [request.prompt]
            else:
                prompts = request.prompt

            promises: List[RequestOutput] = []
            postproc_params_collection: List[Optional[PostprocParams]] = []
            sampling_params = request.to_sampling_params()
            disaggregated_params = to_llm_disaggregated_params(
                request.disaggregated_params)
            for idx, prompt in enumerate(prompts):
                postproc_args = CompletionPostprocArgs.from_request(request)
                postproc_args.prompt_idx = idx
                if request.echo:
                    postproc_args.prompt = prompt
                postproc_params = PostprocParams(
                    post_processor=completion_stream_post_processor
                    if request.stream else completion_response_post_processor,
                    postproc_args=postproc_args,
                )
                promise = self.llm.generate_async(
                    inputs=prompt,
                    sampling_params=sampling_params,
                    _postproc_params=postproc_params,
                    streaming=request.stream,
                    disaggregated_params=disaggregated_params,
                )
                asyncio.create_task(
                    self.await_disconnected(raw_request, promise))
                if not self.postproc_worker_enabled:
                    postproc_args.tokenizer = self.tokenizer
                    postproc_args.num_prompt_tokens = len(
                        promise.prompt_token_ids)
                promises.append(promise)
                postproc_params_collection.append(
                    None if self.postproc_worker_enabled else postproc_params)

            generator = merge_promises(promises, postproc_params_collection)
            if request.stream:
                response_generator = create_completion_generator(generator)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_completion_response(generator)
                return JSONResponse(content=response.model_dump())
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)
        except Exception as e:
            logger.error(f"Encountered an exception: {str(e)}")
            traceback.print_exc()
            return self.create_error_response(str(e))

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()
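
# Minimal launch sketch (illustrative only, not executed by this module):
# the model path below is a placeholder, and `LLM` is the
# tensorrt_llm.llmapi.LLM class imported above.
#
#   llm = LLM(model="/path/to/model")          # hypothetical model path
#   server = OpenAIServer(llm=llm, model="/path/to/model")
#   asyncio.run(server(host="0.0.0.0", port=8000))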