#!/usr/bin/env python
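"""OpenAI-compatible API server built on the TensorRT-LLM high-level API (HLAPI).

The server wraps a ``tensorrt_llm.hlapi.LLM`` instance behind a FastAPI app and
exposes the following endpoints:

* ``GET  /health``              -- liveness probe
* ``GET  /version``             -- TensorRT-LLM version
* ``GET  /v1/models``           -- the served model
* ``POST /v1/completions``      -- OpenAI Completions API
* ``POST /v1/chat/completions`` -- OpenAI Chat Completions API

See the comment block above the ``entrypoint`` command below for an
illustrative launch command and client request.
"""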
import asyncio
import logging
from http import HTTPStatus
from pathlib import Path
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
                    TypedDict)

import click
import uvicorn
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from transformers import AutoTokenizer, PreTrainedTokenizer

# yapf: disable
from tensorrt_llm.hlapi import LLM, BuildConfig, KvCacheConfig
from tensorrt_llm.hlapi.llm import RequestOutput
from tensorrt_llm.hlapi.openai_protocol import (
    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    ChatCompletionResponse, ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
    ChatMessage, CompletionRequest, CompletionResponse,
    CompletionResponseChoice, CompletionResponseStreamChoice,
    CompletionStreamResponse, DeltaMessage, ErrorResponse, FunctionCall,
    ModelCard, ModelList, ToolCall, UsageInfo)
from tensorrt_llm.version import __version__ as VERSION

# yapf: enable

TIMEOUT_KEEP_ALIVE = 5  # seconds.


class ConversationMessage(TypedDict):
    role: str
    content: str


def parse_chat_message_content(
        message: ChatCompletionMessageParam) -> List[ConversationMessage]:
    role = message["role"]
    content = message.get("content")

    if content is None:
        return []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)]

    # Only Iterable[ChatCompletionContentPartTextParam] is supported:
    # concatenate the text parts into a single message.
    texts: List[str] = []
    for part in content:
        part_type = part["type"]
        if part_type == "text":
            text = part["text"]
            texts.append(text)
        else:
            raise NotImplementedError(f"{part_type} is not supported")

    text_prompt = "\n".join(texts)
    return [ConversationMessage(role=role, content=text_prompt)]

class OpenaiServer:

    def __init__(self,
                 llm: LLM,
                 model: str,
                 kv_cache_config: KvCacheConfig,
                 hf_tokenizer: Optional[PreTrainedTokenizer] = None):
        self.llm = llm
        self.kv_cache_config = kv_cache_config
        self.tokenizer = hf_tokenizer

        # Report the directory name as the model id when a local path is given.
        model_dir = Path(model)
        if model_dir.exists() and model_dir.is_dir():
            self.model = model_dir.name
        else:
            self.model = model

        self.app = FastAPI()

        @self.app.exception_handler(RequestValidationError)
        async def validation_exception_handler(_, exc):
            return self.create_error_response(message=str(exc))

        self.register_routes()

    @staticmethod
    def create_error_response(
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> JSONResponse:
        error_response = ErrorResponse(message=message,
                                       type=err_type,
                                       code=status_code.value)
        return JSONResponse(content=error_response.model_dump(),
                            status_code=error_response.code)

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/version", self.version, methods=["GET"])
        self.app.add_api_route("/v1/models", self.get_model, methods=["GET"])
        self.app.add_api_route("/v1/completions",
                               self.openai_completion,
                               methods=["POST"])
        self.app.add_api_route("/v1/chat/completions",
                               self.openai_chat,
                               methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def version(self) -> JSONResponse:
        ver = {"version": VERSION}
        return JSONResponse(content=ver)

    async def get_model(self) -> JSONResponse:
        model_list = ModelList(data=[ModelCard(id=self.model)])
        return JSONResponse(content=model_list.model_dump())

    async def openai_chat(self, request: ChatCompletionRequest) -> Response:

        def get_role() -> str:
            if request.add_generation_prompt:
                role = "assistant"
            else:
                role = request.messages[-1]["role"]
            return role

        def stream_usage_info(prompt_tokens: int, completion_tokens: int):
            if request.stream_options and request.stream_options.include_usage and \
                    request.stream_options.continuous_usage_stats:
                usage = UsageInfo(prompt_tokens=prompt_tokens,
                                  completion_tokens=completion_tokens,
                                  total_tokens=prompt_tokens +
                                  completion_tokens)
            else:
                usage = None
            return usage

        def create_logprobs(token_ids: List[int],
                            logprobs: List[float]) -> ChatCompletionLogProbs:
            assert len(token_ids) == len(logprobs), \
                "token_ids and logprobs have different lengths"
            content: List[ChatCompletionLogProbsContent] = []
            for token_id, logprob in zip(token_ids, logprobs):
                token = self.tokenizer.decode(token_id)
                # returning multiple logprobs is not supported
                first_logprob = ChatCompletionLogProbsContent(
                    token=token,
                    logprob=max(logprob, -9999.0),
                    bytes=list(token.encode("utf-8", errors="replace")))
                content.append(first_logprob)
            chat_logprobs = ChatCompletionLogProbs(content=content)
            return chat_logprobs

        async def chat_stream_generator(
                promise: RequestOutput) -> AsyncGenerator[str, None]:
            first_iteration = True
            num_choices = 1 if request.n is None else request.n
            finish_reason_sent = [False] * num_choices
            role = get_role()

            def yield_first_chat(num_tokens: int,
                                 role: str = None,
                                 content: str = None):
                for i in range(num_choices):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(role=role, content=content),
                        logprobs=None,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        choices=[choice_data], model=self.model)
                    chunk.usage = stream_usage_info(num_tokens, 0)

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            async for res in promise:
                prompt_tokens = len(res.prompt_token_ids)
                if first_iteration:
                    # yield_first_chat is a generator: it must be drained for
                    # the initial role (and optional echo) chunks to be sent.
                    for first_chunk in yield_first_chat(prompt_tokens,
                                                        role=role):
                        yield first_chunk

                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for first_chunk in yield_first_chat(
                                    prompt_tokens, content=last_msg_content):
                                yield first_chunk
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_text = output.text_diff
                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if delta_text:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i, delta=delta_message, finish_reason=None)
                        if request.logprobs:
                            logprobs = output.logprobs_diff
                            token_ids = output.token_ids_diff
                            choice_data.logprobs = create_logprobs(
                                token_ids, logprobs)
                        chunk = ChatCompletionStreamResponse(
                            choices=[choice_data], model=self.model)
                        chunk.usage = stream_usage_info(
                            prompt_tokens, output.length)
                        data = chunk.model_dump_json()
                        yield f"data: {data}\n\n"
                    else:
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                completion_tokens = sum(output.length
                                        for output in promise.outputs)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    choices=[], model=self.model, usage=final_usage)
                final_usage_data = final_usage_chunk.model_dump_json()
                yield f"data: {final_usage_data}\n\n"
            # Terminate the SSE stream, mirroring the completions endpoint.
            yield "data: [DONE]\n\n"

        async def create_chat_response(
                promise: RequestOutput) -> ChatCompletionResponse:
            await promise.aresult()
            choices: List[ChatCompletionResponseChoice] = []
            role = get_role()
            for output in promise.outputs:
                if request.tool_choice and isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                    message = ChatMessage(
                        role=role,
                        content="",
                        tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=output.text))
                        ])
                else:
                    message = ChatMessage(role=role, content=output.text)
                choice = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                )

                if request.logprobs:
                    choice.logprobs = create_logprobs(output.token_ids,
                                                      output.logprobs)
                choices.append(choice)

            if request.echo:
                last_msg_content = ""
                if conversation and conversation[-1].get(
                        "content") and conversation[-1].get("role") == role:
                    last_msg_content = conversation[-1]["content"]
                for choice in choices:
                    full_message = last_msg_content + choice.message.content
                    choice.message.content = full_message

            num_prompt_tokens = len(promise.prompt_token_ids)
            num_generated_tokens = sum(
                len(output.token_ids) for output in promise.outputs)
            usage = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_generated_tokens,
                total_tokens=num_prompt_tokens + num_generated_tokens,
            )
            response = ChatCompletionResponse(
                model=self.model,
                choices=choices,
                usage=usage,
            )
            return response

        try:
            conversation: List[ConversationMessage] = []
            for msg in request.messages:
                conversation.extend(parse_chat_message_content(msg))
            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]
            prompt: str = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                chat_template=request.chat_template,
                **(request.chat_template_kwargs or {}),
            )
            sampling_params = request.to_sampling_params()

            promise = self.llm.generate_async(
                inputs=prompt,
                sampling_params=sampling_params,
                streaming=request.stream,
            )
            if request.stream:
                response_generator = chat_stream_generator(promise)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_chat_response(promise)
                return JSONResponse(content=response.model_dump())
        except Exception as e:
            return self.create_error_response(str(e))

    async def openai_completion(self, request: CompletionRequest) -> Response:

        def merge_promises(
            promises: List[RequestOutput]
        ) -> AsyncIterator[Tuple[int, RequestOutput]]:
            # Interleave the outputs of several in-flight requests into a
            # single async stream of (prompt_index, RequestOutput) tuples.
            outputs = asyncio.Queue()
            finished = [False] * len(promises)

            async def producer(i: int, promise: RequestOutput):
                async for output in promise:
                    await outputs.put((i, output))
                finished[i] = True

            _tasks = [
                asyncio.create_task(producer(i, promise))
                for i, promise in enumerate(promises)
            ]

            async def consumer():
                while not all(finished) or not outputs.empty():
                    item = await outputs.get()
                    yield item
                await asyncio.gather(*_tasks)

            return consumer()

        async def create_completion_generator(
                generator: AsyncIterator[Tuple[int, RequestOutput]],
                num_choices: int):
            num_responses_per_request = 1 if request.n is None else request.n
            echoed = [False] * num_choices
            async for prompt_idx, request_output in generator:
                prompt = request_output.prompt
                for gen_idx, output in enumerate(request_output.outputs):
                    response_idx = prompt_idx * num_responses_per_request + gen_idx
                    delta_text = output.text_diff
                    if request.echo and not echoed[response_idx]:
                        delta_text = prompt + delta_text
                        echoed[response_idx] = True
                    response = CompletionStreamResponse(
                        model=self.model,
                        choices=[
                            CompletionResponseStreamChoice(
                                index=response_idx, text=delta_text)
                        ])
                    response_json = response.model_dump_json(
                        exclude_unset=False)
                    yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        async def create_completion_response(
                generator: AsyncIterator[Tuple[int, RequestOutput]],
                num_choices: int):
            choices = [None] * num_choices
            num_responses_per_request = 1 if request.n is None else request.n
            num_prompt_tokens = num_gen_tokens = 0
            async for prompt_idx, request_output in generator:
                num_prompt_tokens += len(request_output.prompt_token_ids)
                for gen_idx, output in enumerate(request_output.outputs):
                    num_gen_tokens += len(output.token_ids)
                    output_text = output.text
                    if request.echo:
                        output_text = request_output.prompt + output_text
                    idx = prompt_idx * num_responses_per_request + gen_idx
                    choice = CompletionResponseChoice(
                        index=idx,
                        text=output_text,
                    )
                    choices[idx] = choice

            usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_gen_tokens,
                total_tokens=num_gen_tokens + num_prompt_tokens,
            )
            response = CompletionResponse(
                model=self.model,
                choices=choices,
                usage=usage_info,
            )
            return response

        try:
            if isinstance(request.prompt, str) or \
                    (isinstance(request.prompt, list)
                     and isinstance(request.prompt[0], int)):
                prompts = [request.prompt]
            else:
                prompts = request.prompt

            promises: List[RequestOutput] = []
            sampling_params = request.to_sampling_params()
            for prompt in prompts:
                promise = self.llm.generate_async(
                    inputs=prompt,
                    sampling_params=sampling_params,
                    streaming=request.stream,
                )
                promises.append(promise)
            generator = merge_promises(promises)
            num_choices = len(prompts) if request.n is None else len(
                prompts) * request.n
            if request.stream:
                response_generator = create_completion_generator(
                    generator, num_choices)
                return StreamingResponse(content=response_generator,
                                         media_type="text/event-stream")
            else:
                response = await create_completion_response(
                    generator, num_choices)
                return JSONResponse(content=response.model_dump())
        except Exception as e:
            return self.create_error_response(str(e))

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()

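
# Illustrative usage only; the file name, model path, and request payload below
# are placeholders and not part of the upstream example:
#
#   python openai_server.py /path/to/model --tp_size 1 --port 8000
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "<served-model-name>",
#          "messages": [{"role": "user", "content": "Hello!"}],
#          "stream": false}'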


@click.command()
@click.argument("model_dir")
@click.option("--tokenizer", type=str, default=None)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
@click.option("--max_beam_width", type=int, default=1)
@click.option("--tp_size", type=int, default=1)
@click.option("--pp_size", type=int, default=1)
def entrypoint(model_dir: str,
               tokenizer: Optional[str] = None,
               host: Optional[str] = None,
               port: int = 8000,
               max_beam_width: int = 1,
               tp_size: int = 1,
               pp_size: int = 1):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    build_config = BuildConfig(max_batch_size=10,
                               max_beam_width=max_beam_width)

    llm = LLM(model_dir,
              tokenizer,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              build_config=build_config)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer or model_dir)

    server = OpenaiServer(llm=llm,
                          model=model_dir,
                          kv_cache_config=kv_cache_config,
                          hf_tokenizer=hf_tokenizer)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()