# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import os
import time
import uuid
from collections.abc import AsyncGenerator
from copy import copy
from collections import OrderedDict
from typing import Literal, Optional, Union
# yapf: disable
from openai.types.responses import (ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionToolCall,
ResponseInProgressEvent, ResponseOutputItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent)
# yapf: enable
from openai.types.responses.response_function_web_search import (
ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
from openai.types.responses.response_reasoning_item import Content
from openai.types.responses.tool import Tool
from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, Message, ReasoningEffort, Role,
StreamState, SystemContent, TextContent,
ToolDescription, load_harmony_encoding)
from tensorrt_llm.bindings import steady_clock_now
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel,
ResponseInputOutputItem,
ResponsesRequest,
ResponsesResponse)
from .harmony_adapter import HarmonyAdapter
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
ENABLE_RESPONSES_DEBUG_MSG = False
def responses_debug_log(msg):
if ENABLE_RESPONSES_DEBUG_MSG:
logger.debug(msg)
_harmony_encoding = None
def random_uuid():
    return uuid.uuid4().hex
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(
HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def decode_tokens(tokens):
return get_encoding().decode(tokens)
def get_steady_clock_now_in_seconds() -> float:
return steady_clock_now().total_seconds()
def parse_response_input(
input_msg: ResponseInputOutputItem,
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
) -> Message:
if not isinstance(input_msg, dict):
input_msg = input_msg.model_dump()
responses_debug_log(f"------- Parsing input -----------")
responses_debug_log(input_msg)
responses_debug_log("")
if "type" not in input_msg or input_msg["type"] == "message":
role = input_msg["role"]
content = input_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
elif isinstance(content, list):
contents = [
TextContent(text=text_prefix + c["text"]) for c in content
]
msg = Message.from_role_and_contents(role, contents)
else:
logger.warning("Responses API: Invalid input message type")
msg = None
elif input_msg["type"] == "function_call_output":
call_id = input_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
for prev_response in reversed(prev_responses):
if isinstance(prev_response, ResponseFunctionToolCall
) and prev_response.call_id == call_id:
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
input_msg["output"])
elif input_msg["type"] == "reasoning":
content = input_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif input_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT,
input_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{input_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {input_msg['type']}")
return msg
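# A minimal, illustrative sketch (not exercised by the server): how plain
# Responses API input items map onto Harmony messages. The dict shapes below
# mirror the branches handled in parse_response_input; the values are
# hypothetical.
def _example_parse_response_input() -> None:
    # A plain user message becomes a Harmony user message.
    user_msg = parse_response_input(
        {"type": "message", "role": "user", "content": "Hello"}, [])
    assert user_msg is not None
    # A system message is rewritten as developer instructions.
    dev_msg = parse_response_input(
        {"type": "message", "role": "system", "content": "Be brief."}, [])
    assert dev_msg is not None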
class ConversationHistoryStore:
    """LRU store mapping responses to their Harmony conversation histories.

    Keeps at most ``resp_capacity`` responses and ``max_conversations``
    conversations; each conversation is capped at ``resp_capacity * 4``
    messages.
    """

    def __init__(self, resp_capacity: int = 16, max_conversations: int = 32):
        self.response_capacity = resp_capacity
        self.conversation_capacity = resp_capacity * 4
        self.max_conversations = max_conversations
        self.responses_lock = asyncio.Lock()
        self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict()
        self.conversations_lock = asyncio.Lock()
        self.conversations: OrderedDict[str, list[Message]] = OrderedDict()
        self.response_to_conversation: dict[str, str] = {}
        self.conversation_to_response: dict[str, str] = {}
    async def load_response(self,
                            resp_id: str) -> Optional[ResponsesResponse]:
responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}")
async with self.responses_lock:
return self.responses.get(resp_id)
    async def store_response(self,
                             resp: ResponsesResponse,
                             resp_msgs: Optional[list[Message]] = None,
                             prev_resp_id: Optional[str] = None) -> None:
        resp_id = resp.id
        responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}")
        # Avoid a shared mutable default argument: take a fresh list per call.
        resp_msgs = resp_msgs if resp_msgs is not None else []
        async with self.responses_lock:
            self.responses[resp_id] = resp
            if len(self.responses) > self.response_capacity:
                self._pop_response()
        async with self.conversations_lock:
            conversation_id: str
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
            elif prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
            else:
                conversation_id = random_uuid()
                self.conversations[conversation_id] = resp_msgs
            responses_debug_log(
                f"  * storing at conversation id: {conversation_id}")
            # Register the mapping before trimming: _pop_conversation resolves
            # the conversation via resp_id, so trimming first would spin
            # without making progress.
            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            while len(self.conversations[conversation_id]
                      ) > self.conversation_capacity:
                self._pop_conversation(resp_id)
            self._update_visited_conversation(conversation_id)
    async def store_messages(self, resp_id: str, msgs: list[Message],
                             prev_resp_id: Optional[str]):
        responses_debug_log("ConversationHistoryStore storing msg:")
        for msg in msgs:
            responses_debug_log(f"  -> {msg.to_json()}")
        async with self.conversations_lock:
            conversation_id: str
            if prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
            else:
                conversation_id = random_uuid()
            responses_debug_log(
                f"  * storing at conversation: {conversation_id}")
            self.conversations[conversation_id] = msgs
            # Register the mapping before trimming so _pop_conversation can
            # resolve resp_id.
            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)
            self._update_visited_conversation(conversation_id)
async def append_messages(self, resp_id: str, msgs: list[Message]):
responses_debug_log(f"ConversationHistoryStore appending msgs:")
for msg in msgs:
responses_debug_log(f" -> {msg.to_json()}")
async with self.conversations_lock:
assert resp_id in self.response_to_conversation
conversation_id = self.response_to_conversation[resp_id]
responses_debug_log(
f" * appending at conversation: {conversation_id}")
self.conversations[conversation_id].extend(msgs)
if len(self.conversations[conversation_id]
) > self.conversation_capacity:
self._pop_conversation(resp_id)
self._update_visited_conversation(conversation_id)
    async def get_conversation_history(
            self, resp_id: Optional[str]) -> list[Message]:
        responses_debug_log("ConversationHistoryStore getting prev_msgs:")
        responses_debug_log(f"  -> prev_resp_id: {resp_id}")
async with self.conversations_lock:
if resp_id in self.response_to_conversation:
conversation_id = self.response_to_conversation[resp_id]
self._update_visited_conversation(conversation_id)
return self.conversations.get(conversation_id, [])
return []
def _update_visited_conversation(self, conversation_id) -> None:
if conversation_id not in self.conversations:
return
self.conversations.move_to_end(conversation_id)
if len(self.conversations) > self.max_conversations:
removed_id, _ = self.conversations.popitem(last=False)
responses_debug_log(
f"ConversationHistoryStore Removing conversation {removed_id}")
removed_resp_id = self.conversation_to_response[removed_id]
# The responses may have been removed due to response capacity
if removed_resp_id in self.response_to_conversation:
self.response_to_conversation.pop(removed_resp_id)
self.conversation_to_response.pop(removed_id)
    def _pop_conversation(self, resp_id) -> None:
        conversation_id = self.response_to_conversation.get(resp_id, None)
        if conversation_id is None:
            return
        conversation = self.conversations[conversation_id]
        # Find the oldest turn: the first user message(s) up to the "final"
        # message that closes the turn.
        first_conversation_range = []
        for i, msg in enumerate(conversation):
            if msg.author.role == Role.USER:
                first_conversation_range.append(i)
            elif msg.channel == "final":
                first_conversation_range.append(i)
                break
        if len(first_conversation_range) < 2:
            # No complete user -> final turn; drop the oldest message so
            # capacity trimming still makes progress.
            if conversation:
                del conversation[0]
            return
        del conversation[
            first_conversation_range[0]:first_conversation_range[-1] + 1]
def _pop_response(self) -> None:
responses_debug_log(f"responses type: {type(self.responses)}")
resp_id, _ = self.responses.popitem(last=False)
if resp_id in self.response_to_conversation:
self.response_to_conversation.pop(resp_id)
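# Illustrative lifecycle sketch (hypothetical caller, not part of the server
# flow): store a finished response, then fetch its conversation history for
# the follow-up turn.
async def _example_conversation_store(resp: ResponsesResponse,
                                      msgs: list[Message]) -> list[Message]:
    store = ConversationHistoryStore(resp_capacity=4)
    await store.store_response(resp, resp_msgs=msgs)
    # A later request referencing resp.id can replay the stored history.
    return await store.get_conversation_history(resp.id)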
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort])
if start_date:
sys_msg_content = sys_msg_content.with_conversation_start_date(
start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
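# Illustrative: a system message that pins reasoning effort to "low"; the
# remaining fields fall back to openai_harmony defaults.
def _example_get_system_message() -> Message:
    return get_system_message(reasoning_effort="low")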
def get_developer_message(instructions: Optional[str] = None,
tools: Optional[list[Tool]] = None) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
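# Illustrative: a developer message carrying instructions only; function
# tools are omitted for brevity (built-in tools such as web_search_preview
# belong in the system message instead, as handled above).
def _example_get_developer_message() -> Message:
    return get_developer_message(instructions="Answer concisely.")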
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def construct_harmony_messages(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: Optional[list[Message]] = None,
) -> list[Message]:
    """Construct Harmony messages from the request input, including
    conversation history messages when they exist."""
    # Avoid a shared mutable default argument.
    prev_msgs = prev_msgs or []
    messages: list[Message] = []
if prev_response is None:
# New conversation.
reasoning_effort = (request.reasoning.effort
if request.reasoning else None)
        sys_msg = get_system_message(reasoning_effort=reasoning_effort)
messages.append(sys_msg)
dev_msg = get_developer_message(request.instructions, request.tools)
messages.append(dev_msg)
else:
messages.extend(prev_msgs)
# Append the new input.
# Responses API supports simple text inputs without chat format.
if isinstance(request.input, str):
messages.append(get_user_message(request.input))
else:
if prev_response is not None:
prev_outputs = copy(prev_response.output)
else:
prev_outputs = []
for input_msg in request.input:
msg = parse_response_input(input_msg, prev_outputs)
if msg is not None:
messages.append(msg)
            # User passes in a tool call request and its output. We need
# to add the tool call request to prev_outputs so that the
# parse_response_input can find the tool call request when
# parsing the tool call output.
if isinstance(input_msg, ResponseFunctionToolCall):
prev_outputs.append(input_msg)
return messages
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
responses_debug_log("Rendering conversation:")
responses_debug_log(conversation.to_json())
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT)
return token_ids
def parse_output_tokens(tokens: list[int]) -> list[Message]:
return get_encoding().parse_messages_from_completion_tokens(
tokens, role=Role.ASSISTANT)
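# Illustrative round trip (hypothetical values): a fresh request is rendered
# to prompt token ids, and the generated completion token ids are parsed back
# into Harmony messages.
def _example_render_and_parse(request: ResponsesRequest,
                              completion_token_ids: list[int]) -> None:
    prompt_ids = render_for_completion(
        construct_harmony_messages(request, prev_response=None))
    output_msgs = parse_output_tokens(completion_token_ids)
    responses_debug_log(f"{len(prompt_ids)} prompt tokens, "
                        f"{len(output_msgs)} output messages")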
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
"""
Parse a Harmony message into a list of output response items.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search")
elif recipient == "browser.open":
action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}",
type="open_page")
elif recipient == "browser.find":
action = ActionFind(pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find")
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[Content(text=content.text, type="reasoning_text")],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if message.recipient is None:
pass
elif message.recipient.startswith("functions."):
function_name = message.recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
elif message.recipient.startswith(
"python") or message.recipient.startswith("browser"):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[Content(text=content.text, type="reasoning_text")],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {message.recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
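# Illustrative: a final-channel assistant message parses into a single
# ResponseOutputMessage holding one output_text part.
def _example_parse_output_message() -> list[ResponseOutputItem]:
    msg = Message.from_role_and_content(Role.ASSISTANT,
                                        "Hi!").with_channel("final")
    return parse_output_message(msg)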
def finish_reason_mapping(finish_reason: str) -> str:
match finish_reason:
case 'stop':
return 'completed'
case 'length':
return 'incomplete'
case 'timeout':
return 'failed'
case 'cancelled':
return 'cancelled'
    raise RuntimeError(f"Unexpected finish reason: {finish_reason!r}")
async def request_preprocess(request: ResponsesRequest,
                             prev_response: Optional[ResponsesResponse],
                             harmony_adapter: HarmonyAdapter,
                             conversation_store: ConversationHistoryStore,
                             enable_store: bool = False):
    # TODO: fix default_max_tokens
    sampling_params = request.to_sampling_params(
        default_max_tokens=16384,
        default_sampling_params={
            "stop_token_ids": harmony_adapter.get_stop_tokens()
        })
prev_response_id = request.previous_response_id
# TODO: better way to enable metrics
if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
sampling_params.return_perf_metrics = True
prev_msgs = []
if enable_store:
prev_msgs = await conversation_store.get_conversation_history(
prev_response_id)
responses_debug_log(f"Prev msgs:")
for msg in prev_msgs:
responses_debug_log(f" -> {msg.to_json()}")
messages = construct_harmony_messages(request,
prev_response,
prev_msgs=prev_msgs)
if enable_store and request.store:
# Remove reasoning messages to save token usage during multi-turn conversation
msgs_to_store = [msg for msg in messages if msg.channel != "analysis"]
await conversation_store.store_messages(request.request_id,
msgs_to_store, prev_response_id)
input_tokens = render_for_completion(messages)
responses_debug_log("======= Complete Inputs to model =======")
responses_debug_log(decode_tokens(input_tokens))
responses_debug_log("========================================")
return input_tokens, sampling_params
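# Illustrative: per-request perf metrics piggyback on the KV-cache timing
# env var, so setting it (hypothetical path) before launching the server
# turns them on:
#
#   os.environ["TRTLLM_KVCACHE_TIME_OUTPUT_PATH"] = "/tmp/kvcache_times.json"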
async def create_response(
    generator,
    request: ResponsesRequest,
    sampling_params,
    model_name: str,
    conversation_store: ConversationHistoryStore,
    generation_result: Optional[RequestOutput] = None,
    enable_store: bool = False,
    create_time: Optional[int] = None,
) -> ResponsesResponse:
final_res: Optional[RequestOutput] = None
response_creation_time = create_time if create_time is not None else int(
time.time())
prev_response_id = request.previous_response_id
if generation_result is not None:
final_res = generation_result
else:
final_res = await generator
if final_res is None:
raise RuntimeError("No output generated or provided")
responses_debug_log("================================================")
responses_debug_log("RAW MODEL OUTPUT:")
responses_debug_log(final_res.outputs)
responses_debug_log("================================================")
output_messages = parse_output_tokens(final_res.outputs[0].token_ids)
responses_debug_log(f"output messages: {len(output_messages)}")
for msg in output_messages:
responses_debug_log(f" -> {msg.to_json()}")
# prepare responses output
output_content = []
for msg in output_messages:
output_content.extend(parse_output_message(msg))
response = ResponsesResponse.from_request(
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
)
if enable_store and request.store:
await conversation_store.store_response(resp=response,
resp_msgs=output_messages,
prev_resp_id=prev_response_id)
responses_debug_log("========== Response ===========")
responses_debug_log(response)
responses_debug_log("===============================")
return response
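# Illustrative non-streaming flow (hypothetical call site; the LLM-side
# generate_async API is an assumption, not confirmed by this module):
#
#   input_tokens, sampling_params = await request_preprocess(
#       request, prev_response, harmony_adapter, conversation_store)
#   promise = llm.generate_async(input_tokens, sampling_params)
#   response = await create_response(promise, request, sampling_params,
#                                    model_name, conversation_store)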
async def process_streaming_events(
request: ResponsesRequest,
sampling_params: SamplingParams,
generator,
harmony_adapter: HarmonyAdapter,
model_name: str,
conversation_store: ConversationHistoryStore,
        create_time: Optional[int] = None,
        enable_store: bool = False) -> AsyncGenerator[str, None]:
sequence_number = 0
response_creation_time = create_time if create_time is not None else int(
time.time())
final_res: Optional[RequestOutput] = None
def _send_event(event: OpenAIBaseModel):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
    current_content_index = 0  # FIXME: this index is never advanced
    current_output_index = 0
    current_item_id = ""  # FIXME: this id is never populated
sent_output_item_added = False
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
))
yield _send_event(
ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
))
tools = [tool.model_dump() for tool in request.tools]
stream_request_id = f"responses-api-{request.request_id}"
async for res in generator:
final_res = res
output = res.outputs[0]
messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages(
stream_request_id, output.token_ids_diff, tools,
request.tool_choice)
stream_state = harmony_adapter.get_stream_state(stream_request_id)
assert stream_state is not None
parser = stream_state.get_parser()
if parser.state == StreamState.EXPECT_START:
current_output_index += 1
sent_output_item_added = False
if len(messages) > 0:
previous_item = messages[-1]
if previous_item.recipient is not None:
# Deal with tool call here
pass
elif previous_item.channel == "analysis":
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
Content(
text=previous_item.content[0].text,
type="reasoning_text",
),
],
status="completed",
id=current_item_id,
summary=[],
)
yield _send_event(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
))
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
))
elif previous_item.channel == "final":
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
yield _send_event(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=current_item_id,
))
yield _send_event(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=text_content,
))
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
))
if parser.last_content_delta:
if (parser.current_channel == "final"
and parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
))
yield _send_event(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
yield _send_event(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
))
elif (parser.current_channel == "analysis"
and parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
status="in_progress",
),
))
yield _send_event(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
yield _send_event(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
delta=parser.last_content_delta,
sequence_number=-1,
))
    # TODO(JunyiXu-nv): support built-in tools (python/browser/code interpreter)
final_response = await create_response(generator, request, sampling_params,
model_name, conversation_store,
final_res, enable_store,
response_creation_time)
yield _send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
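# Illustrative consumer (hypothetical FastAPI handler): the generator yields
# pre-formatted SSE strings that can be forwarded verbatim.
#
#   from fastapi.responses import StreamingResponse
#   return StreamingResponse(
#       process_streaming_events(request, sampling_params, generator,
#                                harmony_adapter, model_name,
#                                conversation_store),
#       media_type="text/event-stream")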
class ServerArrivalTimeMiddleware:
"""
Custom ASGI middleware to track server arrival time.
We implement this as a pure ASGI middleware instead of using FastAPI's
@app.middleware("http") decorator because the decorator internally uses
BaseHTTPMiddleware, which wraps the ASGI `receive` callable. This wrapping
breaks Request.is_disconnected() functionality - the wrapped receive doesn't
properly forward http.disconnect events while the middleware is waiting in
call_next(), preventing detection of client disconnections during long-running
non-streaming requests.
By implementing pure ASGI middleware, we pass through the original receive/send
callables unchanged, preserving the ability to detect client disconnections.
See: https://github.com/encode/starlette/discussions/2094
"""
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# Add arrival time to scope
scope["state"] = {}
scope["state"][
"server_arrival_time"] = get_steady_clock_now_in_seconds()
# Pass through the original receive/send - no wrapping!
await self.app(scope, receive, send)
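# Illustrative registration (hypothetical app object): pure ASGI middleware
# is applied via add_middleware, and handlers can read the recorded arrival
# time from the request state populated above.
#
#   app.add_middleware(ServerArrivalTimeMiddleware)
#   ...
#   arrival = request.state.server_arrival_time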