TensorRT-LLM/examples/apps/fastapi_server.py

"""
NOTE: This FastAPI-based server is only an example for demonstrating the usage
of TensorRT-LLM LLM API. It is not intended for production use.
For production, use the `trtllm-serve` command. The server exposes OpenAI compatible API endpoints.
"""
#!/usr/bin/env python
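
# Example usage (illustrative only; the model path and prompt are placeholders,
# not values shipped with this example):
#
#   python fastapi_server.py /path/to/model --port 8000
#
#   curl http://localhost:8000/health
#   curl http://localhost:8000/stats
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, my name is", "max_tokens": 32}'
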
import asyncio
import json
import logging
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.executor import CppExecutorError, RequestError
from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams

TIMEOUT_KEEP_ALIVE = 5  # seconds.


class LlmServer:

    def __init__(self, llm: LLM):
        self.llm = llm

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # terminate rank0 worker
            yield
            self.llm.shutdown()

        self.app = FastAPI(lifespan=lifespan)
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/stats", self.stats, methods=["GET"])
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def stats(self) -> Response:
        content = await self.llm.aget_stats()
        return JSONResponse(json.loads(content))

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        ''' Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - streaming: whether to stream the results or not.
        - other fields: the sampling parameters (see `SamplingParams` for details).
        '''
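        # Example request body (an illustrative sketch; `max_tokens` and
        # `temperature` are optional SamplingParams fields, not required keys):
        #   {"prompt": "Hello, my name is", "streaming": true,
        #    "max_tokens": 64, "temperature": 0.8}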
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        streaming = request_dict.pop("streaming", False)
        sampling_params = SamplingParams(**request_dict)

        try:
            promise = self.llm.generate_async(prompt,
                                              streaming=streaming,
                                              sampling_params=sampling_params)

            async def stream_results() -> AsyncGenerator[bytes, None]:
                async for output in promise:
                    yield output.outputs[0].text_diff.encode("utf-8")

            if streaming:
                return StreamingResponse(stream_results())

            # Non-streaming case
            await promise.aresult()
            return JSONResponse({"text": promise.outputs[0].text})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.argument("model_dir")
@click.option("--tokenizer", type=str, default=None)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
@click.option("--max_beam_width", type=int, default=1)
@click.option("--tp_size", type=int, default=1)
@click.option("--pp_size", type=int, default=1)
@click.option("--cp_size", type=int, default=1)
@click.option("--kv_cache_free_gpu_memory_fraction", type=float, default=0.8)
def entrypoint(model_dir: str,
               tokenizer: Optional[str] = None,
               host: Optional[str] = None,
               port: int = 8000,
               max_beam_width: int = 1,
               tp_size: int = 1,
               pp_size: int = 1,
               cp_size: int = 1,
               kv_cache_free_gpu_memory_fraction: float = 0.8):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    build_config = BuildConfig(max_batch_size=10, max_beam_width=max_beam_width)
    kv_cache_config = KvCacheConfig(
        free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction)

    llm = LLM(model_dir,
              tokenizer,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              build_config=build_config,
              kv_cache_config=kv_cache_config)

    server = LlmServer(llm=llm)
    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
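
# Example client for the streaming endpoint (a minimal sketch, not part of the
# server; assumes the server is already running on localhost:8000 and that the
# `requests` package is installed):
#
#   import requests
#
#   resp = requests.post("http://localhost:8000/generate",
#                        json={"prompt": "Hello, my name is",
#                              "streaming": True, "max_tokens": 32},
#                        stream=True)
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)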