TensorRT-LLM/examples/apps/fastapi_server.py

import argparse
import asyncio
import json
from typing import AsyncGenerator, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from tensorrt_llm.executor import GenerationExecutorWorker, GenerationResult

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.

app = FastAPI()

# Set in main() once the worker is constructed; the endpoints assert on it.
executor: Optional[GenerationExecutorWorker] = None
@app.get("/stats")
async def stats() -> Response:
assert executor is not None
return JSONResponse(json.loads(await executor.aget_stats()))
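
# Illustrative usage (not part of the original file): with the server running,
# /stats can be polled directly; the JSON payload is whatever the executor's
# aget_stats() reports at runtime.
#
#   curl http://localhost:8000/stats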
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
assert executor is not None
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt", "")
streaming = request_dict.pop("streaming", False)
promise = executor.generate_async(prompt, streaming, **request_dict)
assert isinstance(promise, GenerationResult)
async def stream_results() -> AsyncGenerator[bytes, None]:
async for output in promise:
yield (json.dumps(output.text_diff) + "\0").encode("utf-8")
if streaming:
return StreamingResponse(stream_results())
# Non-streaming case
await promise.aresult()
return JSONResponse({"text": promise.text})


async def main(args):
    global executor

    with GenerationExecutorWorker(args.model_dir, args.tokenizer_type,
                                  args.max_beam_width) as executor:
        executor.block_subordinates()
        config = uvicorn.Config(app,
                                host=args.host,
                                port=args.port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_dir")
    parser.add_argument("tokenizer_type")
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--max_beam_width", type=int, default=1)
    args = parser.parse_args()

    asyncio.run(main(args))
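
# Illustrative client sketch (not part of the original file): one way to
# launch the server and consume the streaming endpoint. stream_results()
# delimits JSON chunks with a NUL byte ("\0"), so the client splits on it.
# <model_dir> and <tokenizer_type> are placeholders; `requests` is an assumed
# third-party dependency, and chunk boundaries are treated as aligned with the
# server's writes, which holds in practice but is not guaranteed by HTTP.
#
#   python3 fastapi_server.py <model_dir> <tokenizer_type> --port 8000
#
#   import json
#   import requests
#
#   with requests.post("http://localhost:8000/generate",
#                      json={"prompt": "Hello, my name is", "streaming": True},
#                      stream=True) as resp:
#       for chunk in resp.iter_content(chunk_size=None):
#           for part in chunk.split(b"\0"):
#               if part:
#                   print(json.loads(part))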