TensorRT-LLMs/examples/apps/fastapi_server.py

#!/usr/bin/env python
import asyncio
import json
import logging
from typing import AsyncGenerator, Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from tensorrt_llm.hlapi import LLM, BuildConfig, KvCacheConfig, SamplingParams

TIMEOUT_KEEP_ALIVE = 5  # seconds.


class LlmServer:

    def __init__(self, llm: LLM, kv_cache_config: KvCacheConfig):
        self.llm = llm
        self.kv_cache_config = kv_cache_config

        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/stats", self.stats, methods=["GET"])
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def stats(self) -> Response:
        content = await self.llm.aget_stats()
        return JSONResponse(json.loads(content))

    async def health(self) -> Response:
        return Response(status_code=200)
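
    # Illustrative probes for the two GET routes above (not part of the
    # original file; assumes the server is reachable on localhost:8000):
    #   curl -i http://localhost:8000/health   -> 200 with an empty body
    #   curl http://localhost:8000/stats       -> runtime stats as JSON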

    async def generate(self, request: Request) -> Response:
        ''' Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - streaming: whether to stream the results or not.
        - other fields: the sampling parameters (see `SamplingParams` for details).
        '''
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        streaming = request_dict.pop("streaming", False)
        sampling_params = SamplingParams(**request_dict)

        promise = self.llm.generate_async(prompt,
                                          streaming=streaming,
                                          sampling_params=sampling_params)

        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for output in promise:
                yield output.outputs[0].text_diff.encode("utf-8")

        if streaming:
            return StreamingResponse(stream_results())

        # Non-streaming case
        await promise.aresult()
        return JSONResponse({"text": promise.outputs[0].text})
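
    # Illustrative client call (not part of the original file): any field
    # other than "prompt" and "streaming" is forwarded to SamplingParams,
    # so the exact sampling fields accepted depend on this tensorrt_llm
    # version; "temperature" and "top_p" are assumptions here.
    #   curl http://localhost:8000/generate -H "Content-Type: application/json" \
    #        -d '{"prompt": "Hello, my name is", "streaming": false,
    #             "temperature": 0.8, "top_p": 0.95}'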

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.argument("model_dir")
@click.option("--tokenizer", type=str, default=None)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
@click.option("--max_beam_width", type=int, default=1)
@click.option("--tp_size", type=int, default=1)
@click.option("--pp_size", type=int, default=1)
def entrypoint(model_dir: str,
               tokenizer: Optional[str] = None,
               host: Optional[str] = None,
               port: int = 8000,
               max_beam_width: int = 1,
               tp_size: int = 1,
               pp_size: int = 1):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    build_config = BuildConfig(max_batch_size=10,
                               max_beam_width=max_beam_width)

    llm = LLM(model_dir,
              tokenizer,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              build_config=build_config)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)

    server = LlmServer(llm=llm, kv_cache_config=kv_cache_config)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
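
# Illustrative launch command (an assumption, matching the click options above;
# model path is a placeholder):
#   python fastapi_server.py /path/to/model --host 0.0.0.0 --port 8000 --tp_size 2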