# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
import os
import time
import uuid
# yapf: disable
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from copy import copy
from typing import Any, Literal, Optional, OrderedDict, Tuple, Union

from openai.types.responses import (ResponseCompletedEvent,
                                    ResponseContentPartAddedEvent,
                                    ResponseContentPartDoneEvent,
                                    ResponseCreatedEvent,
                                    ResponseFunctionToolCall,
                                    ResponseInProgressEvent, ResponseOutputItem,
                                    ResponseOutputItemAddedEvent,
                                    ResponseOutputItemDoneEvent,
                                    ResponseOutputMessage, ResponseOutputText,
                                    ResponseReasoningItem,
                                    ResponseReasoningTextDeltaEvent,
                                    ResponseReasoningTextDoneEvent,
                                    ResponseTextDeltaEvent,
                                    ResponseTextDoneEvent)
from openai.types.responses.response_content_part_added_event import \
    PartReasoningText
from openai.types.responses.response_content_part_done_event import \
    Part as ResponseContentPart
from openai.types.responses.response_function_web_search import (
    ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
from openai.types.responses.response_reasoning_item import Content
from openai.types.responses.tool import FunctionTool, Tool
from openai_harmony import (Author, Conversation, DeveloperContent,
                            HarmonyEncodingName, Message, ReasoningEffort, Role,
                            StreamState, SystemContent, TextContent,
                            ToolDescription, load_harmony_encoding)
from transformers import AutoProcessor, PretrainedConfig

from tensorrt_llm.bindings import steady_clock_now
from tensorrt_llm.inputs.utils import apply_chat_template
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser,
                                                  ReasoningParserFactory)
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
from tensorrt_llm.serve.openai_protocol import (ChatCompletionMessageParam,
                                                ChatCompletionToolsParam,
                                                FunctionDefinition,
                                                OpenAIBaseModel,
                                                ReasoningAssistantMessage,
                                                ResponseInputOutputItem,
                                                ResponsesRequest,
                                                ResponsesResponse,
                                                StreamingResponsesResponse,
                                                UCompletionRequest,
                                                UCompletionResponse)
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import ToolCallItem
from tensorrt_llm.serve.tool_parser.tool_parser_factory import ToolParserFactory

from .harmony_adapter import HarmonyAdapter, get_harmony_adapter

# yapf: enable

REASONING_EFFORT = {
    "high": ReasoningEffort.HIGH,
    "medium": ReasoningEffort.MEDIUM,
    "low": ReasoningEffort.LOW,
}


ENABLE_RESPONSES_DEBUG_MSG = False


def _responses_debug_log(msg):
    if ENABLE_RESPONSES_DEBUG_MSG:
        logger.info(msg)


_harmony_encoding = None


def _random_uuid():
    return str(uuid.uuid4().hex)


def _get_encoding():
    global _harmony_encoding
    if _harmony_encoding is None:
        _harmony_encoding = load_harmony_encoding(
            HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def _decode_tokens(
        tokens: list[int],
        tokenizer: Optional[Union[TransformersTokenizer,
                                  TokenizerBase]] = None) -> str:
    if tokenizer is not None:
        return tokenizer.decode(tokens)
    return _get_encoding().decode(tokens)


def get_steady_clock_now_in_seconds() -> float:
    return steady_clock_now().total_seconds()


def _parse_response_input(
    input_msg: ResponseInputOutputItem,
    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
) -> Message:
    if not isinstance(input_msg, dict):
        input_msg = input_msg.model_dump()

    _responses_debug_log("------- Parsing input -----------")
    _responses_debug_log(input_msg)
    _responses_debug_log("")

    if "type" not in input_msg or input_msg["type"] == "message":
        role = input_msg["role"]
        content = input_msg["content"]
        if role == "system":
            # The user is trying to set a system message. Convert it to:
            # <|start|>developer<|message|># Instructions
            # {instructions}<|end|>
            role = "developer"
            text_prefix = "Instructions:\n"
        else:
            text_prefix = ""
        if isinstance(content, str):
            msg = Message.from_role_and_content(role, text_prefix + content)
        elif isinstance(content, list):
            contents = [
                TextContent(text=text_prefix + c["text"]) for c in content
            ]
            msg = Message.from_role_and_contents(role, contents)
        else:
            logger.warning("Responses API: Invalid input message type")
            msg = None
    elif input_msg["type"] == "function_call_output":
        call_id = input_msg["call_id"]
        call_response: Optional[ResponseFunctionToolCall] = None
        for prev_response in reversed(prev_responses):
            if isinstance(prev_response, ResponseFunctionToolCall
                          ) and prev_response.call_id == call_id:
                call_response = prev_response
                break
        if call_response is None:
            raise ValueError(f"No call message found for {call_id}")
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{call_response.name}"),
            input_msg["output"])
    elif input_msg["type"] == "reasoning":
        content = input_msg["content"]
        assert len(content) == 1
        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
    elif input_msg["type"] == "function_call":
        msg = Message.from_role_and_content(Role.ASSISTANT,
                                            input_msg["arguments"])
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(f"functions.{input_msg['name']}")
        msg = msg.with_content_type("json")
    else:
        raise ValueError(f"Unknown input type: {input_msg['type']}")
    return msg


class ConversationHistoryStore:

    def __init__(self, resp_capacity: int = 16, max_conversations=32):
        # How many responses can be stored.
        self.response_capacity = resp_capacity
        # How many messages can be stored in a single conversation.
        self.conversation_capacity = resp_capacity * 4
        # How many conversations can be stored.
        self.max_conversations = max_conversations

        self.responses_lock = asyncio.Lock()
        # Response store; once more than response_capacity responses are
        # stored, the least recently used one is evicted.
        self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict()

        self.conversations_lock = asyncio.Lock()
        # Conversation store; once more than max_conversations conversations
        # are stored, the least recently visited one is evicted.
        self.conversations: OrderedDict[str, Union[
            list[Message], list[ChatCompletionMessageParam]]] = OrderedDict()

        # Map from response id to conversation id. 1-to-1 mapping.
        self.response_to_conversation: dict[str, str] = {}

        # Map from conversation id to the id of the latest response in that
        # conversation.
        self.conversation_to_response: dict[str, str] = {}

    async def load_response(self, resp_id: str) -> ResponsesResponse:
        _responses_debug_log(
            f"ConversationHistoryStore loading resp: {resp_id}")
        async with self.responses_lock:
            # Guard the LRU update; move_to_end raises KeyError for ids that
            # have already been evicted.
            if resp_id in self.responses:
                self.responses.move_to_end(resp_id)
            return self.responses.get(resp_id)

    async def store_response(self,
                             resp: ResponsesResponse,
                             resp_msgs: Optional[
                                 Union[list[Message],
                                       list[ChatCompletionMessageParam]]] = [],
                             prev_resp_id: Optional[str] = None) -> None:
        """
        Store the response and its messages (the model's output messages) in
        the conversation store. If a previous response id is provided, the
        messages are appended to that conversation; otherwise a new
        conversation is created.

        Args:
            resp: ResponsesResponse
            resp_msgs: Optional[Union[list[Message], list[ChatCompletionMessageParam]]]
            prev_resp_id: Optional[str]

        Returns:
            None
        """
        resp_id = resp.id
        _responses_debug_log(
            f"ConversationHistoryStore storing resp: {resp_id}")
        if ENABLE_RESPONSES_DEBUG_MSG:
            _responses_debug_log(" -> resp_msgs:")
            for msg in resp_msgs:
                _responses_debug_log(f"    -> {msg}")

        async with self.responses_lock:
            self.responses[resp_id] = resp
            if len(self.responses) > self.response_capacity:
                self._pop_response()

        async with self.conversations_lock:
            conversation_id: str
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
            elif prev_resp_id is not None and prev_resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[prev_resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
                # Trim via prev_resp_id; resp_id is not mapped to the
                # conversation yet at this point.
                while len(self.conversations[conversation_id]
                          ) > self.conversation_capacity:
                    self._pop_conversation(prev_resp_id)
            else:
                if prev_resp_id is not None:
                    logger.warning(
                        f"Previous response id {prev_resp_id} not found in conversation store"
                    )
                conversation_id = _random_uuid()
                self.conversations[conversation_id] = resp_msgs

            _responses_debug_log(
                f" * storing at conversation id: {conversation_id}")

            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            self._update_visited_conversation(conversation_id)

    async def store_messages(self, resp_id: str,
                             msgs: Union[list[Message],
                                         list[ChatCompletionMessageParam]],
                             prev_resp_id: Optional[str]) -> None:
        """
        Store the messages in the conversation store.

        `msgs` should always contain the whole conversation, including both
        the previous messages and the new messages.

        Args:
            resp_id: str
            msgs: Union[list[Message], list[ChatCompletionMessageParam]]: The messages to store.
            prev_resp_id: Optional[str]: The previous response id. If not provided, a new conversation will be created.

        Returns:
            None
        """
        _responses_debug_log("ConversationHistoryStore storing msg:")
        if ENABLE_RESPONSES_DEBUG_MSG:
            for msg in msgs:
                _responses_debug_log(f"    -> {msg}")

        async with self.conversations_lock:
            conversation_id: str
            if prev_resp_id is not None and prev_resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[prev_resp_id]
            else:
                conversation_id = _random_uuid()

            _responses_debug_log(
                f" * storing at conversation: {conversation_id}")
            # Map the response id first so that _pop_conversation can resolve
            # the conversation when trimming below.
            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id

            self.conversations[conversation_id] = msgs
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)

            self._update_visited_conversation(conversation_id)

    async def get_conversation_history(
        self, resp_id: str
    ) -> Union[list[Message], list[ChatCompletionMessageParam]]:
        _responses_debug_log("ConversationHistoryStore getting prev_msgs:")
        _responses_debug_log(f" -> prev_resp_id: {resp_id}")
        async with self.conversations_lock:
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                _responses_debug_log(
                    f" -> getting conversation_id: {conversation_id}")
                self._update_visited_conversation(conversation_id)
                return self.conversations.get(conversation_id, [])

        return []

    def _update_visited_conversation(self, conversation_id) -> None:
        """
        Move the visited conversation to the back (most recently used end) of
        the conversation store, keeping the store ordered by visit time, and
        evict the least recently visited conversation if the number of
        conversations exceeds the limit.

        Args:
            conversation_id: str, the id of the conversation to update.

        Returns:
            None
        """
        if conversation_id not in self.conversations:
            return

        self.conversations.move_to_end(conversation_id)
        if len(self.conversations) > self.max_conversations:
            removed_id, _ = self.conversations.popitem(last=False)
            _responses_debug_log(
                f"ConversationHistoryStore Removing conversation {removed_id}")
            removed_resp_id = self.conversation_to_response[removed_id]
            # The response may have already been removed due to the response
            # capacity limit.
            if removed_resp_id in self.response_to_conversation:
                self.response_to_conversation.pop(removed_resp_id)
            self.conversation_to_response.pop(removed_id)

    def _pop_conversation(self, resp_id) -> None:
        """
        Pop the oldest turn from a conversation. A turn starts with a user
        message and ends with an assistant message. This keeps the number of
        messages in a conversation within the limit.

        Args:
            resp_id: str, the response id of the conversation to pop.

        Returns:
            None
        """
        conversation_id = self.response_to_conversation.get(resp_id, None)
        if conversation_id is None:
            return

        conversation = self.conversations[conversation_id]
        if len(conversation) == 0:
            return

        is_harmony_conversation = isinstance(conversation[0], Message)

        def get_first_conversation_range_harmony():
            start_index = 0
            end_index = 0
            for i, msg in enumerate(conversation):
                if msg.author.role == Role.USER:
                    start_index = i
                elif msg.channel == "final":
                    end_index = i
                    break

            return start_index, end_index

        def get_first_conversation_range():
            start_index = 0
            end_index = 0
            for i, msg in enumerate(conversation):
                if msg.get("role", "") == "user":
                    start_index = i
                elif msg.get("role", "") == "assistant":
                    end_index = i
                    break

            return start_index, end_index

        start_index, end_index = 0, 0
        if is_harmony_conversation:
            start_index, end_index = get_first_conversation_range_harmony()
        else:
            start_index, end_index = get_first_conversation_range()

        del conversation[start_index:end_index + 1]

    def _pop_response(self) -> None:
        _responses_debug_log(f"responses type: {type(self.responses)}")
        resp_id, _ = self.responses.popitem(last=False)
        if resp_id in self.response_to_conversation:
            self.response_to_conversation.pop(resp_id)


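# Illustrative usage sketch of ConversationHistoryStore (assumes an asyncio
# event loop and a ResponsesResponse instance `resp` whose id is "resp-1";
# `model_output_msgs` is a hypothetical list of output messages; this sketch
# is not executed as part of this module):
#
#   store = ConversationHistoryStore(resp_capacity=16, max_conversations=32)
#   await store.store_response(resp, resp_msgs=model_output_msgs)
#   # Later turns chain off the previous response id:
#   prev_msgs = await store.get_conversation_history("resp-1")
#   cached_resp = await store.load_response("resp-1")

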
def _get_system_message(
    model_identity: Optional[str] = None,
    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
    start_date: Optional[str] = None,
    browser_description: Optional[str] = None,
    python_description: Optional[str] = None,
) -> Message:
    sys_msg_content = SystemContent.new()
    if model_identity is not None:
        sys_msg_content = sys_msg_content.with_model_identity(model_identity)
    if reasoning_effort is not None:
        sys_msg_content = sys_msg_content.with_reasoning_effort(
            REASONING_EFFORT[reasoning_effort])
    if start_date:
        sys_msg_content = sys_msg_content.with_conversation_start_date(
            start_date)
    if browser_description is not None:
        sys_msg_content = sys_msg_content.with_tools(browser_description)
    if python_description is not None:
        sys_msg_content = sys_msg_content.with_tools(python_description)
    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
    return sys_msg


def _get_developer_message(instructions: Optional[str] = None,
                           tools: Optional[list[Tool]] = None) -> Message:
    dev_msg_content = DeveloperContent.new()
    if instructions is not None:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools = []
        for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                # These are built-in tools that are added to the system message.
                pass
            elif tool.type == "function":
                function_tools.append(tool)
            else:
                raise ValueError(f"tool type {tool.type} not supported")
        if function_tools:
            function_tool_descriptions = [
                ToolDescription.new(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                ) for tool in function_tools
            ]
            dev_msg_content = dev_msg_content.with_function_tools(
                function_tool_descriptions)
    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
    return dev_msg


def _get_user_message(content: str) -> Message:
    return Message.from_role_and_content(Role.USER, content)


def _construct_harmony_messages(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: list[Message] = [],
) -> list[Message]:
    """Construct messages from the request input, including conversation
    history messages if they exist."""
    messages: list[Message] = []
    if prev_response is None:
        # New conversation.
        reasoning_effort = (request.reasoning.effort
                            if request.reasoning else None)
        sys_msg = _get_system_message(reasoning_effort=reasoning_effort)
        messages.append(sys_msg)
        dev_msg = _get_developer_message(request.instructions, request.tools)
        messages.append(dev_msg)
    else:
        messages.extend(prev_msgs)
    # Append the new input.
    # The Responses API supports simple text inputs without chat format.
    if isinstance(request.input, str):
        messages.append(_get_user_message(request.input))
    else:
        if prev_response is not None:
            prev_outputs = copy(prev_response.output)
        else:
            prev_outputs = []
        for input_msg in request.input:
            msg = _parse_response_input(input_msg, prev_outputs)
            if msg is not None:
                messages.append(msg)
            # The user passes in a tool call request and its output. We need
            # to add the tool call request to prev_outputs so that
            # _parse_response_input can find the tool call request when
            # parsing the tool call output.
            if isinstance(input_msg, ResponseFunctionToolCall):
                prev_outputs.append(input_msg)
    return messages


def _render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    _responses_debug_log("Rendering conversation:")
    _responses_debug_log(conversation.to_json())
    token_ids = _get_encoding().render_conversation_for_completion(
        conversation, Role.ASSISTANT)
    return token_ids


def _parse_output_tokens(tokens: list[int]) -> list[Message]:
    return _get_encoding().parse_messages_from_completion_tokens(
        tokens, role=Role.ASSISTANT)


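# Illustrative sketch of the Harmony prelude and the render/parse round trip
# (assumes the gpt-oss Harmony encoding can be loaded locally and that
# `completion_token_ids` came from a generation run; not executed here):
#
#   messages = [
#       _get_system_message(reasoning_effort="low"),
#       _get_developer_message(instructions="Answer briefly."),
#       _get_user_message("What is 2 + 2?"),
#   ]
#   prompt_tokens = _render_for_completion(messages)
#   # ... run generation, collect completion tokens ...
#   output_msgs = _parse_output_tokens(completion_token_ids)

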
def _parse_output_message_harmony(message: Message) -> list[ResponseOutputItem]:
    """
    Parse a Harmony message into a list of output response items.
    """
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    output_items: list[ResponseOutputItem] = []
    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        if len(message.content) != 1:
            raise ValueError("Invalid number of contents in browser message")
        content = message.content[0]
        browser_call = json.loads(content.text)
        # TODO: translate to url properly!
        if recipient == "browser.search":
            action = ActionSearch(
                query=f"cursor:{browser_call.get('query', '')}", type="search")
        elif recipient == "browser.open":
            action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}",
                                    type="open_page")
        elif recipient == "browser.find":
            action = ActionFind(pattern=browser_call["pattern"],
                                url=f"cursor:{browser_call.get('url', '')}",
                                type="find")
        else:
            raise ValueError(f"Unknown browser action: {recipient}")
        web_search_item = ResponseFunctionWebSearch(
            id=f"ws_{_random_uuid()}",
            action=action,
            status="completed",
            type="web_search_call",
        )
        output_items.append(web_search_item)
    elif message.channel == "analysis":
        for content in message.content:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{_random_uuid()}",
                summary=[],
                type="reasoning",
                content=[Content(text=content.text, type="reasoning_text")],
                status=None,
            )
            output_items.append(reasoning_item)
    elif message.channel == "commentary":
        if message.recipient is None:
            pass
        elif message.recipient.startswith("functions."):
            function_name = message.recipient.split(".")[-1]
            for content in message.content:
                response_item = ResponseFunctionToolCall(
                    arguments=content.text,
                    call_id=f"call_{_random_uuid()}",
                    type="function_call",
                    name=function_name,
                    id=f"fc_{_random_uuid()}",
                )
                output_items.append(response_item)
        elif message.recipient.startswith(
                "python") or message.recipient.startswith("browser"):
            for content in message.content:
                reasoning_item = ResponseReasoningItem(
                    id=f"rs_{_random_uuid()}",
                    summary=[],
                    type="reasoning",
                    content=[Content(text=content.text, type="reasoning_text")],
                    status=None,
                )
                output_items.append(reasoning_item)
        else:
            raise ValueError(f"Unknown recipient: {message.recipient}")
    elif message.channel == "final":
        contents = []
        for content in message.content:
            output_text = ResponseOutputText(
                text=content.text,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            contents.append(output_text)
        text_item = ResponseOutputMessage(
            id=f"msg_{_random_uuid()}",
            content=contents,
            role=message.author.role,
            status="completed",
            type="message",
        )
        output_items.append(text_item)
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items


def finish_reason_mapping(finish_reason: str) -> str:
    match finish_reason:
        case 'stop':
            return 'completed'
        case 'length':
            return 'incomplete'
        case 'timeout':
            return 'failed'
        case 'cancelled':
            return 'cancelled'

    raise RuntimeError(f"Unexpected finish_reason: {finish_reason}")


def _response_output_item_to_chat_completion_message(
        item: Union[dict,
                    ResponseInputOutputItem]) -> ChatCompletionMessageParam:
    if not isinstance(item, dict):
        item = item.model_dump()

    item_type = item.get("type", "")

    match item_type:
        case "":
            if "role" in item:
                return item
            else:
                raise ValueError(f"Invalid input message item: {item}")
        case "message":
            return {
                "role": "assistant",
                "content": item["content"][0]["text"],
            }
        case "reasoning":
            return {
                "role": "assistant",
                "reasoning": item["content"][0]["text"],
            }
        case "function_call":
            return {
                "role": "function",
                "content": item["arguments"],
            }
        case "function_call_output":
            return {
                "role": "tool",
                "content": item["output"],
                "tool_call_id": item["call_id"],
            }
        case _:
            raise ValueError(
                f"Unsupported input item type: {item_type}, item: {item}")


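# For example, a Responses API tool result item (hypothetical ids shown)
#   {"type": "function_call_output", "call_id": "call_abc", "output": "42"}
# maps to the chat-completion message
#   {"role": "tool", "content": "42", "tool_call_id": "call_abc"}.

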
async def _create_input_messages(
    request: ResponsesRequest,
    prev_msgs: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
    messages: list[ChatCompletionMessageParam] = []
    if request.instructions:
        messages.append({
            "role": "system",
            "content": request.instructions,
        })

    # Prepend the conversation history.
    # Skip the reasoning output.
    for msg in prev_msgs:
        if "reasoning" not in msg:
            messages.append(msg)

    # Append the new input.
    # The Responses API supports simple text inputs without chat format.
    if isinstance(request.input, str):
        messages.append({"role": "user", "content": request.input})
    else:
        for inp in request.input:
            messages.append(
                _response_output_item_to_chat_completion_message(inp))

    return messages


def _create_output_messages(
        output_contents: dict[str, Any]) -> list[ChatCompletionMessageParam]:
    """
    Convert output contents to chat completion messages for the conversation
    store.

    Reasoning content is stored on a separate assistant message that is
    skipped when the history is replayed as model input (see
    _create_input_messages), which reduces token usage.

    Input:
        output_contents: dict[str, Any]
        - text_content: Optional[str]
        - reasoning_content: Optional[str]
        - tool_calls: Optional[list[ToolCall]]

    Returns:
        list[ChatCompletionMessageParam]: Chat completion messages for the
        conversation store.
    """
    messages: list[ChatCompletionMessageParam] = []

    text_content = output_contents.get("text_content", None)
    if text_content:
        messages.append({
            "role": "assistant",
            "content": text_content,
        })

    reasoning_content = output_contents.get("reasoning_content", None)
    if reasoning_content:
        reasoning_msg = ReasoningAssistantMessage(
            role="assistant",
            reasoning=reasoning_content,
        )

        tool_calls = output_contents.get("tool_calls", [])
        tool_call_msgs = [{
            "id": call.call_id,
            "function": {
                "arguments": call.arguments,
                "name": call.name,
            },
            "type": "function",
        } for call in tool_calls]

        _responses_debug_log(f"tool_call_msgs: {tool_call_msgs}")
        reasoning_msg["tool_calls"] = tool_call_msgs

        messages.append(reasoning_msg)

    return messages


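# Illustrative input/output shape (hypothetical values):
#
#   _create_output_messages({
#       "text_content": "The answer is 4.",
#       "reasoning_content": None,
#       "tool_calls": [],
#   })
#   # -> [{"role": "assistant", "content": "The answer is 4."}]

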
def _get_chat_completion_function_tools(
        tools: Optional[list[Tool]]) -> list[ChatCompletionToolsParam]:
    function_tools: list[ChatCompletionToolsParam] = []
    if tools is None:
        return function_tools

    for tool in tools:
        if isinstance(tool, FunctionTool):
            function_tools.append(
                ChatCompletionToolsParam(
                    type="function",
                    function=FunctionDefinition(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.parameters,
                    ),
                ))
        else:
            logger.warning(
                f"Unsupported tool type: {type(tool)} for non-gpt-oss models, skipping."
            )

    return function_tools


async def _create_input_tokens(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: list[ChatCompletionMessageParam],
    conversation_store: ConversationHistoryStore,
    enable_store: bool,
    tokenizer: Union[TransformersTokenizer, TokenizerBase],
    model_config: PretrainedConfig,
    processor: AutoProcessor,
) -> Tuple[list[int], Optional[dict[str, list[Any]]]]:
    """
    Create input tokens for the model. Also return the mm data if the model is
    multimodal.

    Returns:
        Tuple[list[int], Optional[dict[str, list[Any]]]]: Input tokens and mm data.
    """
    messages = await _create_input_messages(
        request=request,
        prev_msgs=prev_msgs,
    )

    if enable_store and request.store:
        await conversation_store.store_messages(request.request_id, messages,
                                                request.previous_response_id)

    conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(
        messages, model_config)
    mm_data = await mm_coroutines

    tools_dict = [
        tool.model_dump()
        for tool in _get_chat_completion_function_tools(request.tools)
    ]
    token_ids = apply_chat_template(
        model_type=model_config.model_type,
        tokenizer=tokenizer,
        processor=processor,
        conversation=conversation,
        add_generation_prompt=True,
        tools=tools_dict,
        mm_placeholder_counts=mm_placeholder_counts,
        enable_tokenize=True,
    )

    return token_ids, mm_data


async def _create_input_tokens_harmony(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: list[Message],
    conversation_store: ConversationHistoryStore,
    enable_store: bool,
) -> list[int]:
    messages = _construct_harmony_messages(request,
                                           prev_response,
                                           prev_msgs=prev_msgs)

    if enable_store and request.store:
        # Remove reasoning messages to save token usage during multi-turn
        # conversations.
        msgs_to_store = [msg for msg in messages if msg.channel != "analysis"]
        await conversation_store.store_messages(request.request_id,
                                                msgs_to_store,
                                                request.previous_response_id)

    return _render_for_completion(messages)


async def request_preprocess(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    conversation_store: ConversationHistoryStore,
    enable_store: bool,
    use_harmony: bool,
    tokenizer: Optional[Union[TransformersTokenizer, TokenizerBase]] = None,
    model_config: Optional[PretrainedConfig] = None,
    processor: Optional[AutoProcessor] = None,
) -> tuple[list[int], SamplingParams]:

    sampling_params = request.to_sampling_params(
        default_sampling_params={
            "stop_token_ids":
            get_harmony_adapter().get_stop_tokens() if use_harmony else []
        })

    prev_response_id = request.previous_response_id

    # TODO: better way to enable metrics
    if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
        sampling_params.return_perf_metrics = True

    prev_msgs = []
    if enable_store and prev_response_id is not None:
        prev_msgs = await conversation_store.get_conversation_history(
            prev_response_id)

    _responses_debug_log("Prev msgs:")
    for msg in prev_msgs:
        _responses_debug_log(f" -> {msg}")

    if use_harmony:
        input_tokens = await _create_input_tokens_harmony(
            request=request,
            prev_response=prev_response,
            prev_msgs=prev_msgs,
            conversation_store=conversation_store,
            enable_store=enable_store,
        )
    else:
        input_tokens, _ = await _create_input_tokens(
            request=request,
            prev_response=prev_response,
            prev_msgs=prev_msgs,
            conversation_store=conversation_store,
            enable_store=enable_store,
            tokenizer=tokenizer,
            model_config=model_config,
            processor=processor,
        )

    _responses_debug_log("======= Complete Inputs to model =======")
    _responses_debug_log(_decode_tokens(input_tokens, tokenizer))
    _responses_debug_log("========================================")
    return input_tokens, sampling_params


# TODO(JunyiXu-nv): move to use the same function in postprocess_handlers
# after multiple post processors are supported
def _apply_reasoning_parser(
    reasoning_parser_id: Optional[str],
    output_index: int,
    text: str,
    streaming: bool,
    reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None,
) -> Tuple[str, str]:
    reasoning_parser: Optional[BaseReasoningParser] = None
    if reasoning_parser_id is not None:
        if reasoning_parser_dict is not None:
            if output_index not in reasoning_parser_dict:
                reasoning_parser_dict[
                    output_index] = ReasoningParserFactory.create_reasoning_parser(
                        reasoning_parser_id)

            reasoning_parser = reasoning_parser_dict[output_index]
        else:
            reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
                reasoning_parser_id)

    if reasoning_parser is not None:
        if not streaming:
            result = reasoning_parser.parse(text)
        else:
            result = reasoning_parser.parse_delta(text)
        content, reasoning_content = result.content, result.reasoning_content
    else:
        content, reasoning_content = text, ""

    return content, reasoning_content


def _apply_tool_parser(
    tool_parser_id: Optional[str],
    tools: Optional[list[Tool]],
    output_index: int,
    text: str,
    streaming: bool,
    tool_parser_dict: Optional[dict[int, BaseToolParser]] = None,
) -> Tuple[str, list[ToolCallItem]]:
    tool_parser: Optional[BaseToolParser] = None
    if tool_parser_id is not None and tools is not None:
        if tool_parser_dict is not None:
            if output_index not in tool_parser_dict:
                tool_parser_dict[
                    output_index] = ToolParserFactory.create_tool_parser(
                        tool_parser_id)

            tool_parser = tool_parser_dict[output_index]
        else:
            tool_parser = ToolParserFactory.create_tool_parser(tool_parser_id)

    if tool_parser is not None and tools is not None:
        if not streaming:
            result = tool_parser.detect_and_parse(text, tools)
        else:
            result = tool_parser.parse_streaming_increment(text, tools)
        normal_text, calls = result.normal_text, result.calls
    else:
        normal_text, calls = text, []

    return normal_text, calls


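# Illustrative sketch of the two parser helpers in non-streaming mode (the
# parser ids "deepseek-r1" and "hermes" are hypothetical examples; use
# whichever parsers your deployment actually registers with the factories):
#
#   text, reasoning = _apply_reasoning_parser("deepseek-r1", 0,
#                                             raw_output_text, streaming=False)
#   text, calls = _apply_tool_parser("hermes", available_tools, 0, text,
#                                    streaming=False)

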
async def _create_output_content(
    final_res: RequestOutput,
    reasoning_parser: Optional[str] = None,
    tool_parser: Optional[str] = None,
    tools: Optional[list[Tool]] = None,
) -> Tuple[list[ResponseOutputItem], list[ChatCompletionMessageParam]]:
    output_items: list[ResponseOutputItem] = []
    output_messages: list[ChatCompletionMessageParam] = []
    available_tools = _get_chat_completion_function_tools(tools)

    for output in final_res.outputs:
        text, reasoning_text = _apply_reasoning_parser(reasoning_parser,
                                                       output.index,
                                                       output.text, False)

        # Initialize calls so it is defined even when tool parsing is skipped.
        calls = []
        if text:
            text, calls = _apply_tool_parser(tool_parser, available_tools,
                                             output.index, text, False)

        text_item = None
        reasoning_item = None
        tool_calls_item = []
        # Check again after tool parsing to avoid empty text.
        if text:
            output_text = ResponseOutputText(
                text=text.strip(),
                annotations=[],
                type="output_text",
                logprobs=None,
            )

            text_item = ResponseOutputMessage(
                id=f"msg_{_random_uuid()}",
                content=[output_text],
                role="assistant",
                status="completed",
                type="message",
            )

            output_items.append(text_item)

        if reasoning_text:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{_random_uuid()}",
                summary=[],
                type="reasoning",
                content=[
                    Content(text=reasoning_text.strip(), type="reasoning_text")
                ],
                status=None,
            )
            output_items.append(reasoning_item)

        if calls:
            tool_calls_item = [
                ResponseFunctionToolCall(
                    arguments=call.parameters,
                    call_id=f"call_{_random_uuid()}",
                    name=call.name,
                    type="function_call",
                    id=f"fc_{_random_uuid()}",
                ) for call in calls
            ]
            output_items.extend(tool_calls_item)

        output_messages.extend(
            _create_output_messages({
                "text_content":
                text_item.content[0].text if text_item else None,
                "reasoning_content":
                reasoning_item.content[0].text if reasoning_item else None,
                "tool_calls":
                tool_calls_item,
            }))

    return output_items, output_messages


async def _create_output_content_harmony(
    final_res: RequestOutput
) -> Tuple[list[ResponseOutputItem], list[Message]]:
    output_messages = _parse_output_tokens(final_res.outputs[0].token_ids)
    output_content = []

    if ENABLE_RESPONSES_DEBUG_MSG:
        _responses_debug_log(f"output messages: {len(output_messages)}")
        for msg in output_messages:
            _responses_debug_log(f" -> {msg.to_json()}")

    for msg in output_messages:
        output_content.extend(_parse_output_message_harmony(msg))

    return output_content, output_messages


async def create_response(
    generator,
    request: ResponsesRequest,
    sampling_params: SamplingParams,
    model_name: str,
    conversation_store: ConversationHistoryStore,
    generation_result: Optional[RequestOutput] = None,
    enable_store: bool = False,
    use_harmony: bool = True,
    create_time: Optional[int] = None,
    reasoning_parser: Optional[str] = None,
    tool_parser: Optional[str] = None,
) -> ResponsesResponse:

    final_res: Optional[RequestOutput] = None
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    prev_response_id = request.previous_response_id

    if generation_result is not None:
        final_res = generation_result
    else:
        final_res = await generator

    if final_res is None:
        raise RuntimeError("No output generated or provided")

    _responses_debug_log("================================================")
    _responses_debug_log("RAW MODEL OUTPUT:")
    _responses_debug_log(final_res.outputs)
    _responses_debug_log("================================================")

    # Prepare the responses output.
    output_content = []
    if use_harmony:
        output_content, output_messages = await _create_output_content_harmony(
            final_res)
    else:
        output_content, output_messages = await _create_output_content(
            final_res, reasoning_parser, tool_parser, request.tools)

    response = ResponsesResponse.from_request(
        request=request,
        sampling_params=sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=output_content,
        status=finish_reason_mapping(final_res.outputs[0].finish_reason),
    )

    if enable_store and request.store:
        await conversation_store.store_response(resp=response,
                                                resp_msgs=output_messages,
                                                prev_resp_id=prev_response_id)

    _responses_debug_log("========== Response ===========")
    _responses_debug_log(response)
    _responses_debug_log("===============================")
    return response


class ResponsesStreamingStateTracker:
    current_content_index: int = 0
    current_output_index: int = 0
    current_item_id: str = ""
    sent_output_item_added: bool = False

    # Only for non-harmony streaming
    text_sent: bool = False
    reasoning_sent: bool = False


class ResponsesStreamingEventsHelper:

    def __init__(self):
        self.state_tracker = ResponsesStreamingStateTracker()

    def content_index_increment(self):
        self.state_tracker.current_content_index += 1

    def output_index_increment(self):
        self.state_tracker.current_output_index += 1

    @property
    def item_id(self) -> str:
        return self.state_tracker.current_item_id

    @item_id.setter
    def item_id(self, item_id: str):
        self.state_tracker.current_item_id = item_id

    @property
    def is_output_item_added_sent(self) -> bool:
        return self.state_tracker.sent_output_item_added

    @is_output_item_added_sent.setter
    def is_output_item_added_sent(self, is_sent: bool):
        self.state_tracker.sent_output_item_added = is_sent

    @property
    def is_text_sent(self) -> bool:
        return self.state_tracker.text_sent

    @is_text_sent.setter
    def is_text_sent(self, is_sent: bool):
        self.state_tracker.text_sent = is_sent

    @property
    def is_reasoning_sent(self) -> bool:
        return self.state_tracker.reasoning_sent

    @is_reasoning_sent.setter
    def is_reasoning_sent(self, is_sent: bool):
        self.state_tracker.reasoning_sent = is_sent

    def get_response_created_event(
            self, response: ResponsesResponse) -> ResponseCreatedEvent:
        return ResponseCreatedEvent(
            type="response.created",
            sequence_number=-1,  # will be set by the _send_event function
            response=response,
        )

    def get_response_in_progress_event(
            self, response: ResponsesResponse) -> ResponseInProgressEvent:
        return ResponseInProgressEvent(
            type="response.in_progress",
            sequence_number=-1,
            response=response,
        )

    def get_reasoning_text_done_event(
            self, text: str) -> ResponseReasoningTextDoneEvent:
        return ResponseReasoningTextDoneEvent(
            type="response.reasoning_text.done",
            item_id=self.state_tracker.current_item_id,
            sequence_number=-1,
            output_index=self.state_tracker.current_output_index,
            content_index=self.state_tracker.current_content_index,
            text=text,
        )

    def get_text_done_event(self, text: str,
                            logprobs: list[float]) -> ResponseTextDoneEvent:
        return ResponseTextDoneEvent(
            type="response.output_text.done",
            sequence_number=-1,
            output_index=self.state_tracker.current_output_index,
            content_index=self.state_tracker.current_content_index,
            text=text,
            logprobs=logprobs,
            item_id=self.state_tracker.current_item_id,
        )

    def get_content_part_done_event(
            self, part: ResponseContentPart) -> ResponseContentPartDoneEvent:
        return ResponseContentPartDoneEvent(
            type="response.content_part.done",
            sequence_number=-1,
            item_id=self.state_tracker.current_item_id,
            output_index=self.state_tracker.current_output_index,
            content_index=self.state_tracker.current_content_index,
            part=part,
        )

    def get_output_item_done_event(
            self, item: ResponseOutputItem) -> ResponseOutputItemDoneEvent:
        return ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            sequence_number=-1,
            output_index=self.state_tracker.current_output_index,
            item=item,
        )

    def get_output_item_added_event(
            self, item: ResponseOutputItem) -> ResponseOutputItemAddedEvent:
        return ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            sequence_number=-1,
            output_index=self.state_tracker.current_output_index,
            item=item,
        )

    def get_content_part_added_event(
            self, part: ResponseContentPart) -> ResponseContentPartAddedEvent:
        return ResponseContentPartAddedEvent(
            type="response.content_part.added",
            sequence_number=-1,
            output_index=self.state_tracker.current_output_index,
            item_id=self.state_tracker.current_item_id,
            content_index=self.state_tracker.current_content_index,
            part=part,
        )

    def get_text_delta_event(self, delta: str,
                             logprobs: list[float]) -> ResponseTextDeltaEvent:
        return ResponseTextDeltaEvent(
            type="response.output_text.delta",
            sequence_number=-1,
            content_index=self.state_tracker.current_content_index,
            output_index=self.state_tracker.current_output_index,
            item_id=self.state_tracker.current_item_id,
            delta=delta,
            logprobs=logprobs,
        )

    def get_reasoning_text_delta_event(
            self, delta: str) -> ResponseReasoningTextDeltaEvent:
        return ResponseReasoningTextDeltaEvent(
            type="response.reasoning_text.delta",
            item_id=self.state_tracker.current_item_id,
            output_index=self.state_tracker.current_output_index,
            content_index=self.state_tracker.current_content_index,
            delta=delta,
            sequence_number=-1,
        )

    def _get_output_added_events(
        self, output_item: ResponseOutputMessage | ResponseReasoningItem
    ) -> list[StreamingResponsesResponse]:
        """
        Get the item-added event and content-part-added event for a message
        item whose generation is starting.

        Returns:
            list[StreamingResponsesResponse]: A list of streaming response
            events.
        """
        if not self.is_output_item_added_sent:
            self.is_output_item_added_sent = True

            if output_item.type == "message":
                content_part = ResponseOutputText(
                    type="output_text",
                    text="",
                    annotations=[],
                    logprobs=[],
                )
            elif output_item.type == "reasoning":
                content_part = PartReasoningText(
                    type="reasoning_text",
                    text="",
                )
            else:
                raise ValueError(
                    f"Unknown content part type: {output_item.type}")

            yield self.get_output_item_added_event(output_item)
            yield self.get_content_part_added_event(content_part)

    def get_message_output_added_events(
            self) -> list[StreamingResponsesResponse]:
        return self._get_output_added_events(output_item=ResponseOutputMessage(
            id=self.item_id,
            type="message",
            role="assistant",
            content=[],
            status="in_progress",
        ))

    def get_reasoning_output_added_events(
            self) -> list[StreamingResponsesResponse]:
        return self._get_output_added_events(output_item=ResponseReasoningItem(
            id=self.item_id,
            type="reasoning",
            summary=[],
            status="in_progress",
        ))


def _should_send_done_events(
    output: RequestOutput,
    output_index: int,
    reasoning_parser_id: Optional[str] = None,
    tool_parser_id: Optional[str] = None,
    tools: Optional[list[Tool]] = None,
    reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None,
    tool_parser_dict: Optional[dict[int, BaseToolParser]] = None,
    streaming_events_helper: Optional[ResponsesStreamingEventsHelper] = None,
    finished_generation: bool = False,
) -> Tuple[bool, bool, Optional[str], Optional[str]]:
    """
    Determine whether done events should be sent for text or reasoning items.

    Analyzes the complete output text to detect when reasoning or text
    sections have been completed and should receive done events.

    Args:
        output: RequestOutput containing the full generated text in output.text
        output_index: Index of the output being processed
        reasoning_parser_id: Parser ID for extracting reasoning content
        tool_parser_id: Parser ID for extracting tool calls
        tools: Available tools for tool parsing
        reasoning_parser_dict: Dictionary of reasoning parsers
        tool_parser_dict: Dictionary of tool parsers
        streaming_events_helper: Helper tracking the current streaming state
        finished_generation: Whether generation has finished

    Returns:
        Tuple of (should_send_reasoning_done, should_send_text_done,
                  reasoning_content, text_content)
    """
    should_send_reasoning_done = False
    should_send_text_done = False
    reasoning_content = ""
    text_content = ""

    # TODO(JunyiXu-nv): find a more efficient way to decide if we need to send done events
    # Parse the complete output using non-streaming mode to get the full content.
    full_text, full_reasoning = _apply_reasoning_parser(
        reasoning_parser_id=reasoning_parser_id,
        output_index=output_index,
        text=output.text,
        streaming=False,
        reasoning_parser_dict=reasoning_parser_dict,
    )

    # Apply tool parsing to get tool calls.
    tool_calls = []
    if full_text:
        full_text, tool_calls = _apply_tool_parser(
            tool_parser_id=tool_parser_id,
            tools=tools,
            output_index=output_index,
            text=full_text,
            streaming=False,
            tool_parser_dict=tool_parser_dict,
        )

    # Detect the reasoning -> text transition.
    # Reasoning is done when reasoning content has been sent and text content
    # is now available.
    if full_reasoning and full_text:
        if streaming_events_helper and streaming_events_helper.is_reasoning_sent and not streaming_events_helper.is_text_sent:
            should_send_reasoning_done = True
            reasoning_content = full_reasoning

    # Detect the text -> tool call transition.
    # Text is done when text content has been sent and tool calls are now
    # available.
    if full_text and tool_calls:
        if streaming_events_helper and streaming_events_helper.is_text_sent:
            should_send_text_done = True
            text_content = full_text

    # Also check if text is done because generation finished (no tool calls
    # case). Text is done when generation completes and text has been sent.
    if full_text and not tool_calls and finished_generation:
        if streaming_events_helper and streaming_events_helper.is_text_sent:
            should_send_text_done = True
            text_content = full_text

    # Similarly, reasoning is done if generation finished with only reasoning
    # (no text case).
    if full_reasoning and not full_text and finished_generation:
        if streaming_events_helper and streaming_events_helper.is_reasoning_sent:
            should_send_reasoning_done = True
            reasoning_content = full_reasoning

    return should_send_reasoning_done, should_send_text_done, reasoning_content, text_content


def _generate_streaming_event(
    output: RequestOutput,
    request: ResponsesRequest,
    finished_generation: bool,
    streaming_events_helper: ResponsesStreamingEventsHelper,
    reasoning_parser_id: Optional[str] = None,
    tool_parser_id: Optional[str] = None,
    reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None,
    tool_parser_dict: Optional[dict[int, BaseToolParser]] = None,
):
    available_tools = _get_chat_completion_function_tools(request.tools)
    output_idx = output.index
    delta_text = output.text_diff
    calls = []

    def check_parser(parser_id: Optional[str],
                     parser_dict: Optional[dict[int, BaseReasoningParser]]):
        if parser_id is not None:
            if parser_dict is None:
                raise RuntimeError(
                    f"Parser({parser_id}) dictionary is not provided for streaming"
                )

    check_parser(reasoning_parser_id, reasoning_parser_dict)
    check_parser(tool_parser_id, tool_parser_dict)

    delta_text, reasoning_delta_text = _apply_reasoning_parser(
        reasoning_parser_id=reasoning_parser_id,
        output_index=output_idx,
        text=delta_text,
        streaming=True,
        reasoning_parser_dict=reasoning_parser_dict,
    )

    if delta_text:
        # TODO(JunyiXu-nv): handle tool calls in streaming mode
        delta_text, calls = _apply_tool_parser(
            tool_parser_id=tool_parser_id,
            tools=available_tools,
            output_index=output_idx,
            text=delta_text,
            streaming=True,
            tool_parser_dict=tool_parser_dict,
        )

    _responses_debug_log(
        repr(
            f" ---------> delta text: {delta_text}, reasoning delta text: {reasoning_delta_text}, calls: {calls}"
        ))

    # Check if we need to send done events for completed sections.
    should_send_reasoning_done, should_send_text_done, reasoning_full_content, text_full_content = _should_send_done_events(
        output=output,
        output_index=output_idx,
        reasoning_parser_id=reasoning_parser_id,
        tool_parser_id=tool_parser_id,
        tools=available_tools,
        reasoning_parser_dict=reasoning_parser_dict,
        tool_parser_dict=tool_parser_dict,
        streaming_events_helper=streaming_events_helper,
        finished_generation=finished_generation,
    )

    # Send done events if needed.
    if should_send_reasoning_done and reasoning_full_content:
        reasoning_item = ResponseReasoningItem(
            id=streaming_events_helper.item_id,
            summary=[],
            type="reasoning",
            content=[
                Content(text=reasoning_full_content, type="reasoning_text")
            ],
            status="completed",
        )
        yield streaming_events_helper.get_reasoning_text_done_event(
            reasoning_full_content)
        yield streaming_events_helper.get_output_item_done_event(reasoning_item)
        streaming_events_helper.output_index_increment()
        streaming_events_helper.is_output_item_added_sent = False
        streaming_events_helper.is_reasoning_sent = False

    if should_send_text_done and text_full_content:
        text_content = ResponseOutputText(
            text=text_full_content,
            annotations=[],
            type="output_text",
            logprobs=None,
        )
        text_item = ResponseOutputMessage(
            id=streaming_events_helper.item_id,
            content=[text_content],
            role="assistant",
            status="completed",
            type="message",
        )
        yield streaming_events_helper.get_text_done_event(text_full_content, [])
        yield streaming_events_helper.get_content_part_done_event(text_content)
        yield streaming_events_helper.get_output_item_done_event(text_item)
        streaming_events_helper.output_index_increment()
        streaming_events_helper.is_output_item_added_sent = False
        streaming_events_helper.is_text_sent = False

    # Send delta events for ongoing content.
    if delta_text:
        if delta_text.strip():
            if not streaming_events_helper.is_text_sent:
                streaming_events_helper.is_text_sent = True
                yield from streaming_events_helper.get_message_output_added_events()
            yield streaming_events_helper.get_text_delta_event(delta_text, [])
    elif reasoning_delta_text:
        if reasoning_delta_text.strip():
            if not streaming_events_helper.is_reasoning_sent:
                streaming_events_helper.is_reasoning_sent = True
                yield from streaming_events_helper.get_reasoning_output_added_events(
                )
            yield streaming_events_helper.get_reasoning_text_delta_event(
                reasoning_delta_text)


def _generate_streaming_event_harmony(
    harmony_adapter: HarmonyAdapter,
    stream_request_id: str,
    output: RequestOutput,
    request: ResponsesRequest,
    streaming_events_helper: ResponsesStreamingEventsHelper,
):
    tools = [tool.model_dump() for tool in request.tools]
    messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages(
        stream_request_id, output.token_ids_diff, tools, request.tool_choice)
    stream_state = harmony_adapter.get_stream_state(stream_request_id)
    assert stream_state is not None
    parser = stream_state.get_parser()
    if parser.state == StreamState.EXPECT_START:
        streaming_events_helper.output_index_increment()
        streaming_events_helper.is_output_item_added_sent = False

        if len(messages) > 0:
            previous_item = messages[-1]
            if previous_item.recipient is not None:
                # Deal with tool calls here.
                pass
            elif previous_item.channel == "analysis":
                reasoning_item = ResponseReasoningItem(
                    type="reasoning",
                    content=[
                        Content(
                            text=previous_item.content[0].text,
                            type="reasoning_text",
                        ),
                    ],
                    status="completed",
                    id=streaming_events_helper.item_id,
                    summary=[],
                )
                yield streaming_events_helper.get_reasoning_text_done_event(
                    previous_item.content[0].text)
                yield streaming_events_helper.get_output_item_done_event(
                    reasoning_item)

            elif previous_item.channel == "final":
                text_content = ResponseOutputText(
                    type="output_text",
                    text=previous_item.content[0].text,
                    annotations=[],
                )

                text_item = ResponseOutputMessage(
                    id=streaming_events_helper.item_id,
                    type="message",
                    role="assistant",
                    content=[text_content],
                    status="completed",
                )

                yield streaming_events_helper.get_text_done_event(
                    previous_item.content[0].text, [])
                yield streaming_events_helper.get_content_part_done_event(
                    text_content)
                yield streaming_events_helper.get_output_item_done_event(
                    text_item)

    if parser.last_content_delta:
        if (parser.current_channel == "final"
                and parser.current_recipient is None):
            if not streaming_events_helper.is_output_item_added_sent:
                streaming_events_helper.is_output_item_added_sent = True

                output_item = ResponseOutputMessage(
                    id=streaming_events_helper.item_id,
                    type="message",
                    role="assistant",
                    content=[],
                    status="in_progress",
                )

                content_part = ResponseOutputText(
                    type="output_text",
                    text="",
                    annotations=[],
                    logprobs=[],
                )
                yield streaming_events_helper.get_output_item_added_event(
                    output_item)
                yield streaming_events_helper.get_content_part_added_event(
                    content_part)

            yield streaming_events_helper.get_text_delta_event(
                parser.last_content_delta, [])

        elif (parser.current_channel == "analysis"
              and parser.current_recipient is None):
            if not streaming_events_helper.is_output_item_added_sent:
                streaming_events_helper.is_output_item_added_sent = True

                reasoning_item = ResponseReasoningItem(
                    id=streaming_events_helper.item_id,
                    type="reasoning",
                    summary=[],
                    status="in_progress",
                )

                reasoning_content = PartReasoningText(
                    type="reasoning_text",
                    text="",
                )

                yield streaming_events_helper.get_output_item_added_event(
                    reasoning_item)
                yield streaming_events_helper.get_content_part_added_event(
                    reasoning_content)

            yield streaming_events_helper.get_reasoning_text_delta_event(
                parser.last_content_delta)


async def process_streaming_events(
    generator,
    request: ResponsesRequest,
    sampling_params: SamplingParams,
    model_name: str,
    conversation_store: ConversationHistoryStore,
    enable_store: bool = False,
    use_harmony: bool = True,
    create_time: Optional[int] = None,
    reasoning_parser: Optional[str] = None,
    tool_parser: Optional[str] = None,
) -> AsyncGenerator[str, None]:
    sequence_number = 0
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    final_res: Optional[RequestOutput] = None
    reasoning_parser_dict: dict[int, BaseReasoningParser] = {}
    tool_parser_dict: dict[int, BaseToolParser] = {}

    def _send_event(event: OpenAIBaseModel):
        nonlocal sequence_number
        # Set sequence_number if the event has this attribute.
        if hasattr(event, 'sequence_number'):
            event.sequence_number = sequence_number
            sequence_number += 1
        # Get the event type from the event's type field if it exists.
        event_type = getattr(event, 'type', 'unknown')
        return (f"event: {event_type}\n"
                f"data: {event.model_dump_json(indent=None)}\n\n")

    streaming_events_helper = ResponsesStreamingEventsHelper()

    initial_response = ResponsesResponse.from_request(
        request,
        sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=[],
        status="in_progress",
        usage=None,
    ).model_dump()

    yield _send_event(
        streaming_events_helper.get_response_created_event(initial_response))
    yield _send_event(
        streaming_events_helper.get_response_in_progress_event(
            initial_response))

    stream_request_id = f"responses-api-{request.request_id}"
    harmony_adapter = get_harmony_adapter()
    async for res in generator:
        final_res = res
        # TODO(JunyiXu-nv): handle multiple outputs
        output = res.outputs[0]

        event_generator = None
        if use_harmony:
            event_generator = _generate_streaming_event_harmony(
                harmony_adapter=harmony_adapter,
                stream_request_id=stream_request_id,
                output=output,
                request=request,
                streaming_events_helper=streaming_events_helper,
            )
        else:
            event_generator = _generate_streaming_event(
                output=output,
                request=request,
                finished_generation=res.finished,
                streaming_events_helper=streaming_events_helper,
                reasoning_parser_id=reasoning_parser,
                tool_parser_id=tool_parser,
                reasoning_parser_dict=reasoning_parser_dict,
                tool_parser_dict=tool_parser_dict,
            )

        if event_generator is None:
            raise RuntimeError("Failed to generate streaming events")

        for event in event_generator:
            yield _send_event(event)

    final_response = await create_response(
        generator=generator,
        request=request,
        sampling_params=sampling_params,
        model_name=model_name,
        conversation_store=conversation_store,
        generation_result=final_res,
        enable_store=enable_store,
        use_harmony=use_harmony,
        create_time=response_creation_time,
        reasoning_parser=reasoning_parser,
        tool_parser=tool_parser,
    )

    yield _send_event(
        ResponseCompletedEvent(
            type="response.completed",
            sequence_number=-1,
            response=final_response.model_dump(),
        ))


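# Each chunk yielded by process_streaming_events is a server-sent event in
# standard SSE framing, as produced by _send_event above. Illustrative wire
# format (payloads abbreviated):
#
#   event: response.created
#   data: {"id": "resp-...", "status": "in_progress", ...}
#
#   event: response.output_text.delta
#   data: {"delta": "Hello", "output_index": 0, "sequence_number": 2, ...}

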
class ServerArrivalTimeMiddleware:
    """
    Custom ASGI middleware to track server arrival time.

    We implement this as a pure ASGI middleware instead of using FastAPI's
    @app.middleware("http") decorator because the decorator internally uses
    BaseHTTPMiddleware, which wraps the ASGI `receive` callable. This wrapping
    breaks Request.is_disconnected() functionality - the wrapped receive
    doesn't properly forward http.disconnect events while the middleware is
    waiting in call_next(), preventing detection of client disconnections
    during long-running non-streaming requests.

    By implementing pure ASGI middleware, we pass through the original
    receive/send callables unchanged, preserving the ability to detect client
    disconnections.

    See: https://github.com/encode/starlette/discussions/2094
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            # Add the arrival time to the scope.
            scope["state"] = {}
            scope["state"][
                "server_arrival_time"] = get_steady_clock_now_in_seconds()

        # Pass through the original receive/send - no wrapping!
        await self.app(scope, receive, send)


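# Illustrative wiring sketch (assumes a FastAPI/Starlette `app` and FastAPI's
# Request import; not executed here). Because this is pure ASGI middleware,
# it can be installed with add_middleware, and the arrival time written into
# scope["state"] is readable via request.state:
#
#   app.add_middleware(ServerArrivalTimeMiddleware)
#
#   @app.get("/ping")
#   async def ping(request: Request):
#       return {"arrived_at": request.state.server_arrival_time}

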
class ResponseHooks(ABC):
    """
    Hooks for response processing and (disagg) service perf observability.
    """

    @abstractmethod
    def on_req_begin(self, request: UCompletionRequest):
        pass

    @abstractmethod
    def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
        pass

    @abstractmethod
    def on_first_token(self,
                       gen_server: str,
                       request: UCompletionRequest,
                       response: Optional[UCompletionResponse] = None):
        pass

    @abstractmethod
    def on_resp_done(self,
                     gen_server: str,
                     request: UCompletionRequest,
                     response: Optional[UCompletionResponse] = None):
        pass


async def done_generator() -> AsyncGenerator[bytes, None]:
    yield "data: [DONE]\n\n".encode('utf-8')


UCompletionResponseOrGenerator = Union[UCompletionResponse,
                                       AsyncGenerator[Any, None]]