"""
|
|
NOTE: This FastAPI-based server is only an example for demonstrating the usage
|
|
of TensorRT-LLM LLM API. It is not intended for production use.
|
|
For production, use the `trtllm-serve` command. The server exposes OpenAI compatible API endpoints.
|
|
"""
|
|
|
|
#!/usr/bin/env python
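
# Example usage (hypothetical paths and script name, shown only for
# illustration; the CLI options are defined in `entrypoint` below):
#
#   python fastapi_server.py /path/to/model --host 0.0.0.0 --port 8000
#
# and an example (non-streaming) query against the `/generate` route:
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, my name is", "streaming": false}'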

import asyncio
import json
import logging
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.executor import CppExecutorError, RequestError
from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams

TIMEOUT_KEEP_ALIVE = 5  # seconds.


class LlmServer:

    def __init__(self, llm: LLM):
        self.llm = llm

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            yield
            # terminate rank0 worker
            self.llm.shutdown()

        self.app = FastAPI(lifespan=lifespan)
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/stats", self.stats, methods=["GET"])
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def stats(self) -> Response:
        content = await self.llm.aget_stats()
        return JSONResponse(json.loads(content))

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        ''' Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - streaming: whether to stream the results or not.
        - other fields: the sampling parameters (see `SamplingParams` for details).
        An example payload is sketched in the comment below.
        '''
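        # A hypothetical request body; the extra fields are assumed to map onto
        # `SamplingParams` arguments (e.g. `max_tokens`, `temperature`); check
        # the `SamplingParams` definition for the exact names your version accepts:
        #
        #   {"prompt": "The capital of France is",
        #    "streaming": false,
        #    "max_tokens": 32,
        #    "temperature": 0.8}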
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        streaming = request_dict.pop("streaming", False)

        sampling_params = SamplingParams(**request_dict)

        try:
            promise = self.llm.generate_async(prompt,
                                              streaming=streaming,
                                              sampling_params=sampling_params)

            async def stream_results() -> AsyncGenerator[bytes, None]:
                async for output in promise:
                    yield output.outputs[0].text_diff.encode("utf-8")

            if streaming:
                return StreamingResponse(stream_results())

            # Non-streaming case
            await promise.aresult()
            return JSONResponse({"text": promise.outputs[0].text})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.argument("model_dir")
@click.option("--tokenizer", type=str, default=None)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
@click.option("--max_beam_width", type=int, default=1)
@click.option("--tp_size", type=int, default=1)
@click.option("--pp_size", type=int, default=1)
@click.option("--cp_size", type=int, default=1)
@click.option("--kv_cache_free_gpu_memory_fraction", type=float, default=0.8)
def entrypoint(model_dir: str,
               tokenizer: Optional[str] = None,
               host: Optional[str] = None,
               port: int = 8000,
               max_beam_width: int = 1,
               tp_size: int = 1,
               pp_size: int = 1,
               cp_size: int = 1,
               kv_cache_free_gpu_memory_fraction: float = 0.8):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    build_config = BuildConfig(max_batch_size=10, max_beam_width=max_beam_width)

    kv_cache_config = KvCacheConfig(
        free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction)

    llm = LLM(model_dir,
              tokenizer,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              build_config=build_config,
              kv_cache_config=kv_cache_config)

    server = LlmServer(llm=llm)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
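
# The remaining routes can be exercised like this once the server is running
# (hypothetical host/port, shown only for illustration):
#
#   curl http://localhost:8000/health            # liveness probe, returns 200
#   curl http://localhost:8000/stats             # runtime statistics as JSON
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "streaming": true}'   # streams raw text chunks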