TensorRT-LLMs/tests/integration/evaltool/patch_for_server/server.py
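
"""FastAPI-based inference server used by the evaltool integration tests (patch_for_server).

Each MPI rank serves a worker app on port 12479 + rank; rank 0 additionally
spawns a public OpenAI-style server on port 12478 that forwards every request
to all worker ranks and returns rank 0's response.
"""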

import argparse
import asyncio
import httpx
import uvicorn
from fastapi import Body, FastAPI, HTTPException
from openai_schema import (ChatCompletionRequest, ChatCompletionResponse,
CompletionRequest, CompletionResponse)
from trt_llm_model import TRTLLMModel
import tensorrt_llm
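
# Rank of this MPI process and total world size; each rank runs its own worker server.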
runtime_rank = tensorrt_llm.mpi_rank()
world_size = tensorrt_llm.mpi_world_size()

# API Endpoints
app = FastAPI(
title="TRT-LLM Inference Server",
version="0.0.1",
)
@app.post("/v1/chat/completions/forward")
async def create_chat_completion(req: ChatCompletionRequest = Body(...), ):
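    """Worker endpoint: flatten the chat messages into a plain-text prompt and
    run generation on this rank's TensorRT-LLM engine."""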
# TODO: Update the following logic to align `llm_name` with the `mtbench` convention:
# \w+\/\w+ (e.g., meta-llama/Llama-2-7b).
#
# Refer to the Slack thread for more details: https://nvidia.slack.com/archives/C06DF9J4ZRA/p1721894616172659?thread_ts=1721881115.768169&cid=C06DF9J4ZRA
# Currently, the `trtllm` server accepts `llm_name` in a format not supported by `mtbench`.
# We can replace `/` with `-` to match the Hugging Face model naming conventions.
#
# if req.model != model.engine_path:
# raise HTTPException(
# status_code=400,
# detail=f"Invalid model name specified. Please use the correct model: `{model.engine_path}`.",
# )
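    # A minimal sketch of the normalization described above (hypothetical, not
    # wired in yet): mtbench_style_name = req.model.replace("/", "-")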
context = ""
for m in req.messages:
if m.role == "system":
context += m.content + "\n\n"
elif m.role == "user":
context += "User:\n" + m.content + "\n\n"
elif m.role == "assistant":
context += "Assistant:\n" + m.content + "\n\n"
else:
raise NotImplementedError("Unsupported role: " + str(m.role))
kwargs = {"context": [context]}
if req.max_tokens is not None:
kwargs.update({"max_output_len": req.max_tokens})
if req.temperature is not None:
kwargs.update({"temperature": req.temperature})
if req.top_p is not None:
kwargs.update({"top_p": req.top_p})
if req.stop is not None:
kwargs.update({"stop": req.stop})
prediction = model.generate_text(**kwargs)
choice = {
"index": 0,
"message": {
"role": "assistant",
"content": prediction["predictions"][0],
},
}
if req.logprobs:
choice.update({
"logprobs": {
"content": [{
"token": token,
"logprob": logprob
} for token, logprob in zip(prediction["tokens"][0],
prediction["logprobs"][0])]
},
})
response = {"model": req.model, "choices": [choice]}
return response
@app.post("/v1/completions/forward")
async def create_completion(req: CompletionRequest = Body(...), ):
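    """Worker endpoint: run a plain text completion on this rank's TensorRT-LLM
    engine, optionally echoing the prompt and returning token logprobs."""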
# TODO: Update the following logic to align `llm_name` with the `mtbench` convention:
# \w+\/\w+ (e.g., meta-llama/Llama-2-7b).
#
# Refer to the Slack thread for more details: https://nvidia.slack.com/archives/C06DF9J4ZRA/p1721894616172659?thread_ts=1721881115.768169&cid=C06DF9J4ZRA
# Currently, the `trtllm` server accepts `llm_name` in a format not supported by `mtbench`.
# We can replace `/` with `-` to match the Hugging Face model naming conventions.
#
# if req.model != model.engine_path:
# raise HTTPException(
# status_code=400,
# detail=f"Invalid model name specified. Please use the correct model: `{model.engine_path}`.",
# )
kwargs = {"context": [req.prompt]}
if req.max_tokens is not None:
kwargs.update({"max_output_len": req.max_tokens})
if req.temperature is not None:
kwargs.update({"temperature": req.temperature})
if req.top_p is not None:
kwargs.update({"top_p": req.top_p})
if req.echo is not None:
kwargs.update({"echo": req.echo})
if req.stop is not None:
kwargs.update({"stop": req.stop})
prediction = model.generate_text(**kwargs)
choice = {
"index": 0,
"text": prediction["predictions"][0],
}
if req.logprobs:
assert ("logprobs" in prediction.keys()
), "TRT LLM engine hasn't returned context_logits"
choice.update({
"logprobs": {
"token_logprobs": prediction["logprobs"][0],
"tokens": prediction["tokens"][0],
},
})
response = {"model": req.model, "choices": [choice]}
return response


def run_server(port):
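    """Public OpenAI-style server (rank 0 only): broadcast each request to the
    per-rank worker servers on ports 12479 + rank and return rank 0's answer."""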
headers = {
"Content-Type": "application/json",
}
app2 = FastAPI(
title="TRT-LLM Inference Server",
version="0.0.1",
)
@app2.post("/v1/chat/completions")
async def create_chat_completion(
req: ChatCompletionRequest = Body(...), ) -> ChatCompletionResponse:
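        """Forward the chat request to every rank's worker so that all ranks run
        generation; only rank 0's response is returned to the caller."""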
async with httpx.AsyncClient(timeout=httpx.Timeout(36000.0)) as client:
tasks = []
for i in range(0, world_size):
_port = 12479 + i
task = client.post(
f"http://0.0.0.0:{_port}/v1/chat/completions/forward",
headers=headers,
json=req.dict())
tasks.append(task)
try:
responses = await asyncio.gather(*tasks)
except httpx.RequestError as e:
print(f"Network error: {e}")
return {"error": "Network error occurred during request"}
return ChatCompletionResponse(**responses[0].json())
@app2.post("/v1/completions")
async def create_completion(
req: CompletionRequest = Body(...), ) -> CompletionResponse:
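        """Forward the completion request to every rank's worker and return
        rank 0's response."""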
async with httpx.AsyncClient(timeout=httpx.Timeout(36000.0)) as client:
tasks = []
for i in range(0, world_size):
_port = 12479 + i
task = client.post(
f"http://0.0.0.0:{_port}/v1/completions/forward",
headers=headers,
json=req.dict())
tasks.append(task)
try:
responses = await asyncio.gather(*tasks)
except httpx.RequestError as e:
print(f"Network error: {e}")
return {"error": "Network error occurred during request"}
return CompletionResponse(**responses[0].json())

    uvicorn.run(app2, host="0.0.0.0", port=port)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--engine_path", type=str, required=True)
parser.add_argument("--tokenizer_path", type=str, required=True)
parser.add_argument("--max_output_len", type=int, default=256)
parser.add_argument("--lookahead_config", type=str, default=None)
args = parser.parse_args()
print(args)
model = TRTLLMModel(engine_path=args.engine_path,
tokenizer_path=args.tokenizer_path,
max_output_len=args.max_output_len,
lookahead_config=args.lookahead_config)
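
    # Rank 0 additionally launches the public-facing server on port 12478 in a
    # separate process; every rank then serves its own worker app below.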
if runtime_rank == 0:
import multiprocessing
p = multiprocessing.Process(target=run_server, args=(12478, ))
p.start()
uvicorn.run(app, host="0.0.0.0", port=12479 + runtime_rank)
    # Only rank 0 started the front-server process, so only rank 0 can join it.
    if runtime_rank == 0:
        p.join()
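
# Example launch (a sketch; the exact MPI launcher and paths depend on the setup):
#   mpirun -n <world_size> python server.py --engine_path <engine_dir> \
#       --tokenizer_path <tokenizer_dir>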