Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][feat] Cherry-pick Responses API and multiple postprocess workers support for chat harmony (#7600)

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Co-authored-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Co-authored-by: Tao Li @ NVIDIA <tali@nvidia.com>

This commit is contained in:
parent d60dad6b9d
commit ac0df0a393
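For context before the diff itself: this commit wires a Responses API route (/v1/responses) into the OpenAI-compatible server alongside the existing chat completions route. A minimal client-side sketch of exercising the new endpoint with the official openai Python SDK might look like the following; the port, API key, and model name are illustrative assumptions, not values taken from this commit.

    # Hypothetical usage sketch for the /v1/responses route added below.
    # Assumes a local OpenAI-compatible server is already running on port 8000
    # with a gpt-oss (harmony) model; all names here are illustrative.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

    # Non-streaming call: `model` and `input` mirror the ResponsesRequest
    # fields defined later in this diff.
    response = client.responses.create(
        model="gpt-oss-20b",  # assumed model name
        input="Give me three facts about GPUs.",
        max_output_tokens=256,
    )
    print(response.output)

    # Streaming uses the same route with stream=True; each event corresponds to
    # an SSE chunk produced by the streaming path added in responses_utils.py.
    for event in client.responses.create(
            model="gpt-oss-20b",
            input="Now summarize them in one sentence.",
            stream=True,
    ):
        print(event.type)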
@@ -6,7 +6,7 @@ import re
 import time
 import traceback
 import uuid
-from typing import Any, AsyncGenerator, Literal
+from typing import Any, List, Literal

 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
@@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)

-from tensorrt_llm.llmapi import RequestOutput
 from tensorrt_llm.logger import logger

 # yapf: disable
-from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
+from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionResponse,
                               ChatCompletionResponseChoice,
                               ChatCompletionResponseStreamChoice,
-                              ChatCompletionStreamResponse, ChatMessage,
+                              ChatCompletionStreamResponse,
+                              ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage, DeltaToolCall,
                               UsageInfo)

@@ -57,7 +57,8 @@ class HarmonyStreamState:
             # Normal case: filter based on available tools
             self.should_filter_tools = True
             self.available_tools = {
-                tool.get("function", {}).get("name", "")
+                tool.get("function", {}).get("name", "") if tool.get(
+                    "name", None) is None else tool.get("name")
                 for tool in available_tools
             }
             self.available_tools.discard("")
@@ -78,6 +79,9 @@ class HarmonyStreamState:

         logger.debug("Created HarmonyStreamState for request %s", request_id)

+    def get_parser(self) -> StreamableParser:
+        return self.parser
+
     def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
         """
         Process a batch of tokens while maintaining parsing state.
@@ -125,6 +129,42 @@ class HarmonyStreamState:

         return deltas

+    def process_token_batch_to_messages(self,
+                                        tokens: list[int]) -> list[Message]:
+        """
+        Process a batch of tokens while maintaining parsing state.
+        Returns OpenAI Messages for Responses API
+        """
+        self.tokens_processed += len(tokens)
+
+        for token in tokens:
+            # Store previous state for transition detection
+            prev_channel = self.parser.current_channel
+            prev_recipient = self.parser.current_recipient
+
+            # Process the token
+            self.parser.process(token)
+
+            # Detect channel/recipient transitions AFTER processing each token
+            channel_changed = prev_channel != self.parser.current_channel
+            recipient_changed = prev_recipient != self.parser.current_recipient
+
+            if channel_changed or recipient_changed:
+                # Mark any active tool calls as completed if we're leaving a tool call
+                if prev_channel == "commentary" and prev_recipient and "functions." in str(
+                        prev_recipient):
+                    func_name = str(prev_recipient).split("functions.")[-1]
+                    for tool_id, tool_info in self.tool_calls.items():
+                        if tool_info["name"] == func_name and tool_info.get(
+                                "active", True):
+                            tool_info["active"] = False
+
+                # Reset channel state for new channel
+                self.channel_started = False
+                self.current_channel_state = None
+
+        return self.parser.messages
+
     def _create_closing_token_delta(self) -> dict[str, Any] | None:
         """Create closing token delta for channel transition."""
         if not self.current_channel_state or not self.channel_started:
@@ -317,6 +357,9 @@ class HarmonyAdapter:
             "<|constrain|>": 200009,
         }

+    def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
+        return self._stream_states.get(request_id, None)
+
     def get_stop_tokens(self) -> list[int]:
         """
         Return the list of stop token IDs for Harmony format.
@@ -1214,6 +1257,42 @@ class HarmonyAdapter:
             # Return empty deltas to continue processing
             return []

+    def stateful_stream_harmony_tokens_to_openai_messages(
+            self,
+            request_id: str,
+            tokens: list[int],
+            available_tools: list[dict[str, Any]] | None = None,
+            tool_choice: str | None = None) -> list[Message]:
+        """
+        Process tokens using stateful parsing.
+
+        This method maintains persistent state across multiple calls for the same request,
+        ensuring proper channel transitions and tool call handling.
+
+        Args:
+            request_id: Request ID to maintain state per request
+            tokens: New tokens from this iteration
+            available_tools: Available tools for filtering
+
+        Returns:
+            List of OpenAI Messages
+        """
+        stream_state = self._stream_states.get(request_id, None)
+        if stream_state is None:
+            stream_state = self.create_stream_state(request_id, available_tools,
+                                                    tool_choice)
+
+        try:
+            messages = stream_state.process_token_batch_to_messages(tokens)
+            return messages
+        except (HarmonyError, UnicodeDecodeError, ValueError):
+            logger.error(
+                f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
+            )
+            logger.debug(f"Problematic streaming tokens: {tokens}")
+
+            return []
+
     def create_openai_streaming_response(
             self,
             request_id: str,
@@ -1406,36 +1485,72 @@ class HarmonyAdapter:
         return True


-async def handle_streaming_response(
-    harmony_adapter: HarmonyAdapter,
-    generator: RequestOutput,
-    request_id: str,
-    request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
-    """Handle streaming response with harmony format."""
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+
+
+def get_harmony_adapter():
+    global _SERVE_HARMONY_ADAPTER
+    if _SERVE_HARMONY_ADAPTER is None:
+        _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
+
+    return _SERVE_HARMONY_ADAPTER
+
+
+def handle_streaming_response(tools: List[ChatCompletionToolsParam],
+                              tool_choice: str, outputs: List, model: str,
+                              request_id: str, done: bool,
+                              num_prompt_tokens: int):
     first_iteration = True
-    async for res in generator:
-        output = res.outputs[0]
+    output = outputs[0]

     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
-    if tool_choice == "none":
-        tools_for_parser = None
+    if tool_choice == "none":
+        tools_for_parser = None
+    else:
+        tools_for_parser = tools_dict
+
+    # Create OpenAI streaming responses
+    try:
+        res = []
+        if done:
+            # Clean up state
+            harmony_adapter.cleanup_stream_state(request_id)
+
+            usage_info = _create_usage_info(num_prompt_tokens, outputs)
+
+            # Send final message with finish_reason
+            final_response = ChatCompletionStreamResponse(
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(),
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+                ],
+            )
+
+            final_response_json = final_response.model_dump_json(
+                exclude_none=True)
+            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                             model=model,
+                                                             usage=usage_info)
+            final_usage_json = final_usage_chunk.model_dump_json(
+                exclude_none=True)
+            res.append(f"data: {final_response_json}\n\n")
+            res.append(f"data: {final_usage_json}\n\n")
         else:
-            tools_for_parser = tools_dict
-
-        # Create OpenAI streaming responses
-        try:
             responses = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
-                model_name=request.model,
+                model_name=model,
                 tool_choice=tool_choice)
             # Send first response after receiving the first output
             if first_iteration:
@@ -1446,64 +1561,44 @@ async def handle_streaming_response(
                     delta=first_delta)

                 first_response = ChatCompletionStreamResponse(
-                    model=request.model,
+                    model=model,
                     choices=[choice],
                 )

                 response_json = first_response.model_dump_json(
                     exclude_none=True)
-                yield f"data: {response_json}\n\n"
+                res.append(f"data: {response_json}\n\n")

-            for response in responses:
-                yield response
+            res.extend(responses)

-        except Exception as e:
-            logger.error(f"Failed to create OpenAI streaming response: {e}")
-            logger.debug(f"Streaming error details: {traceback.format_exc()}")
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            raise e
+        return res

-    # Clean up state
-    harmony_adapter.cleanup_stream_state(request_id)
-
-    # Send final message with finish_reason
-    output = generator.outputs[0]
-    final_response = ChatCompletionStreamResponse(
-        model=request.model,
-        choices=[
-            ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason)
-        ])
-
-    yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
-    yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error(f"Failed to create OpenAI streaming response: {e}")
+        logger.debug(f"Streaming error details: {traceback.format_exc()}")
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+        raise e


-async def handle_non_streaming_response(
-    harmony_adapter: HarmonyAdapter, promise: RequestOutput,
-    request: ChatCompletionRequest) -> ChatCompletionResponse:
+def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
+                                  tool_choice: str, outputs: List, model: str,
+                                  num_prompt_tokens: int):
     """Handle non-streaming response with harmony format."""
-    # Get final result
-    await promise
-
     # Parse harmony output to OpenAI format
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
     if tool_choice == "none":
         tools_for_parser = None
     else:
         tools_for_parser = tools_dict

-    output = promise.outputs[0]
+    output = outputs[0]
     parsed_output = harmony_adapter.harmony_output_to_openai(
         output.token_ids, tools_for_parser, tool_choice)

@@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
                                             output.finish_reason)

    # Create usage info from metrics (RequestOutput doesn't have usage in v1)
-    usage_info = _create_usage_info(promise)
+    usage_info = _create_usage_info(num_prompt_tokens, outputs)

     # Create response
     response = ChatCompletionResponse(
-        model=request.model,
+        model=model,
         choices=[
             ChatCompletionResponseChoice(
                 index=0,
@@ -1534,7 +1629,6 @@ async def handle_non_streaming_response(
     # Optional: Log if harmony parsing failed (for debugging)
     if parsed_output.get('_harmony_parsing_failed'):
         logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
-        logger.debug(f"request\n\n{request}")
         logger.debug(f"response\n\n{response}\n")

     return response
@@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
     return reason


-def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
+def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
     """Create usage info from RequestOutput following serving_chat.py pattern."""
-    # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
-    assert final_res.prompt_token_ids is not None
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-
     # Calculate completion tokens from all outputs
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in outputs)

     # Create usage info
     usage = UsageInfo(prompt_tokens=num_prompt_tokens,
@@ -11,9 +11,16 @@ from openai.types.chat import \
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
 from openai.types.chat import \
     ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
+from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseInputItemParam, ResponseOutputItem,
+                                    ResponsePrompt, ResponseReasoningItem,
+                                    ResponseStatus, ResponseTextConfig)
+from openai.types.responses.response import ToolChoice
+from openai.types.responses.tool import Tool
+from openai.types.shared import Metadata, Reasoning
 from openai_harmony import ReasoningEffort
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Annotated, Required, TypedDict
+from typing_extensions import Annotated, Required, TypeAlias, TypedDict

 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
@@ -665,6 +672,208 @@ class ChatCompletionRequest(OpenAIBaseModel):
         return data


+ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
+                                           ResponseReasoningItem,
+                                           ResponseFunctionToolCall]
+
+
+class ResponsesRequest(OpenAIBaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/responses/create
+    background: Optional[bool] = False
+    include: Optional[list[
+        Literal[
+            "code_interpreter_call.outputs",
+            "computer_call_output.output.image_url",
+            "file_search_call.results",
+            "message.input_image.image_url",
+            "message.output_text.logprobs",
+            "reasoning.encrypted_content",
+        ],
+    ]] = None
+    input: Union[str, list[ResponseInputOutputItem]]
+    instructions: Optional[str] = None
+    max_output_tokens: Optional[int] = None
+    max_tool_calls: Optional[int] = None
+    metadata: Optional[Metadata] = None
+    model: str
+    parallel_tool_calls: Optional[bool] = False
+    previous_response_id: Optional[str] = None
+    prompt: Optional[ResponsePrompt] = None
+    reasoning: Optional[Reasoning] = None
+    service_tier: Literal["auto", "default", "flex", "scale",
+                          "priority"] = "auto"
+    store: Optional[bool] = True
+    stream: Optional[bool] = False
+    temperature: Optional[float] = None
+    text: Optional[ResponseTextConfig] = None
+    tool_choice: ToolChoice = "auto"
+    tools: list[Tool] = Field(default_factory=list)
+    top_logprobs: Optional[int] = 0
+    top_p: Optional[float] = None
+    truncation: Optional[Literal["auto", "disabled"]] = "disabled"
+    user: Optional[str] = None
+
+    request_id: str = Field(
+        default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."),
+    )
+
+    _DEFAULT_SAMPLING_PARAMS = {
+        "temperature": 1.0,
+        "top_p": 1.0,
+    }
+
+    def to_sampling_params(
+        self,
+        default_max_tokens: int,
+        default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
+        if self.max_output_tokens is None:
+            max_tokens = default_max_tokens
+        else:
+            max_tokens = min(self.max_output_tokens, default_max_tokens)
+
+        default_sampling_params = default_sampling_params or {}
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        stop_token_ids = default_sampling_params.get("stop_token_ids")
+
+        # Structured output
+        guided_decoding = None
+        if self.text is not None and self.text.format is not None:
+            response_format = self.text.format
+            if response_format.type == "json_schema":
+                guided_decoding = GuidedDecodingParams(
+                    json=response_format.schema_)
+            elif response_format.type == "json_object":
+                raise NotImplementedError("json_object is not supported")
+
+        return SamplingParams(
+            temperature=temperature,
+            top_p=top_p,
+            max_tokens=max_tokens,
+            logprobs=self.top_logprobs,
+            stop_token_ids=stop_token_ids,
+            guided_decoding=guided_decoding,
+        )
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_background(cls, data):
+        if not data.get("background"):
+            return data
+        if not data.get("store", True):
+            raise ValueError("background can only be used when `store` is true")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_prompt(cls, data):
+        if data.get("prompt") is not None:
+            raise ValueError("prompt template is not supported")
+        return data
+
+
+class InputTokensDetails(OpenAIBaseModel):
+    cached_tokens: int
+
+
+class OutputTokensDetails(OpenAIBaseModel):
+    reasoning_tokens: int
+
+
+class ResponseUsage(OpenAIBaseModel):
+    input_tokens: int
+    input_tokens_details: InputTokensDetails
+    output_tokens: int
+    output_tokens_details: OutputTokensDetails
+    total_tokens: int
+
+
+class ResponsesResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}")
+    created_at: int = Field(default_factory=lambda: int(time.time()))
+    # error: Optional[ResponseError] = None
+    # incomplete_details: Optional[IncompleteDetails] = None
+    instructions: Optional[str] = None
+    metadata: Optional[Metadata] = None
+    model: str
+    object: Literal["response"] = "response"
+    output: list[ResponseOutputItem]
+    parallel_tool_calls: bool
+    temperature: float
+    tool_choice: ToolChoice
+    tools: list[Tool]
+    top_p: float
+    background: bool
+    max_output_tokens: int
+    max_tool_calls: Optional[int] = None
+    previous_response_id: Optional[str] = None
+    prompt: Optional[ResponsePrompt] = None
+    reasoning: Optional[Reasoning] = None
+    service_tier: Literal["auto", "default", "flex", "scale", "priority"]
+    status: ResponseStatus
+    text: Optional[ResponseTextConfig] = None
+    top_logprobs: int
+    truncation: Literal["auto", "disabled"]
+    usage: Optional[ResponseUsage] = None
+    user: Optional[str] = None
+
+    @classmethod
+    def from_request(
+        cls,
+        request: ResponsesRequest,
+        sampling_params: SamplingParams,
+        model_name: str,
+        created_time: int,
+        output: list[ResponseOutputItem],
+        status: ResponseStatus,
+        usage: Optional[ResponseUsage] = None,
+    ) -> "ResponsesResponse":
+        return cls(
+            id=request.request_id,
+            created_at=created_time,
+            instructions=request.instructions,
+            metadata=request.metadata,
+            model=model_name,
+            output=output,
+            parallel_tool_calls=request.parallel_tool_calls,
+            temperature=sampling_params.temperature,
+            tool_choice=request.tool_choice,
+            tools=request.tools,
+            top_p=sampling_params.top_p,
+            background=request.background,
+            max_output_tokens=sampling_params.max_tokens,
+            max_tool_calls=request.max_tool_calls,
+            previous_response_id=request.previous_response_id,
+            prompt=request.prompt,
+            reasoning=request.reasoning,
+            service_tier=request.service_tier,
+            status=status,
+            text=request.text,
+            top_logprobs=sampling_params.logprobs,
+            truncation=request.truncation,
+            user=request.user,
+            usage=usage,
+        )
+
+
+class ResponsesStreamResponse(OpenAIBaseModel):
+    response: ResponsesResponse
+    sequence_number: int
+    type: Literal["response.created", "response.in_progress",
+                  "response.completed", "response.failed",
+                  "response.incomplete"]
+
+
 def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
     if opaque_state is None:
         return None
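For orientation (not part of the diff): the new ResponsesRequest model above can be exercised on its own. Validation happens at construction time, and to_sampling_params() folds max_output_tokens, temperature, and top_p into TensorRT-LLM SamplingParams. The model name and the 512-token default below are illustrative assumptions.

    # Hypothetical usage sketch of the ResponsesRequest model added above;
    # numeric defaults and the model name are assumptions, not commit values.
    from tensorrt_llm.serve.openai_protocol import ResponsesRequest

    req = ResponsesRequest(
        model="gpt-oss-20b",  # assumed model name
        input="Explain KV-cache reuse in two sentences.",
        max_output_tokens=128,
        temperature=0.2,
    )

    # request_id is auto-generated with the "resp_" prefix that the server
    # later checks when resolving previous_response_id.
    print(req.request_id)

    # Map the request onto TensorRT-LLM sampling parameters; the output cap is
    # min(max_output_tokens, default_max_tokens).
    sampling_params = req.to_sampling_params(default_max_tokens=512)
    print(sampling_params.max_tokens, sampling_params.temperature)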
@@ -41,17 +41,25 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 CompletionResponse,
                                                 CompletionResponseChoice,
                                                 ErrorResponse, ModelCard,
-                                                ModelList, UsageInfo,
+                                                ModelList, ResponsesRequest,
+                                                UsageInfo,
                                                 to_llm_disaggregated_params)
 from tensorrt_llm.serve.postprocess_handlers import (
-    ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
-    chat_stream_post_processor, completion_response_post_processor,
-    completion_stream_post_processor)
+    ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
+    chat_harmony_post_processor, chat_harmony_streaming_post_processor,
+    chat_response_post_processor, chat_stream_post_processor,
+    completion_response_post_processor, completion_stream_post_processor)
+from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
+from tensorrt_llm.serve.responses_utils import \
+    create_response as responses_api_create_response
+from tensorrt_llm.serve.responses_utils import \
+    process_streaming_events as responses_api_process_streaming_events
+from tensorrt_llm.serve.responses_utils import \
+    request_preprocess as responses_api_request_preprocess
 from tensorrt_llm.version import __version__ as VERSION

 from .._utils import nvtx_mark, set_prometheus_multiproc_dir
-from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
-                              handle_streaming_response,
+from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
                               maybe_transform_reasoning_effort)

 # yapf: enale
@@ -83,6 +91,12 @@ class OpenAIServer:
             logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
             self.model_config = None

+        # Enable response storage for Responses API
+        self.enable_store = True
+        if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
+            self.enable_store = False
+        self.conversation_store = ConversationHistoryStore()
+
         model_dir = Path(model)
         if model_dir.exists() and model_dir.is_dir():
             self.model = model_dir.name
@@ -104,7 +118,11 @@ class OpenAIServer:

         # gpt-oss
         self.harmony_adapter: HarmonyAdapter | None = None
-        self.use_harmony = self.model_config.model_type == "gpt_oss"
+        disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
+        if disable_harmony:
+            self.use_harmony = False
+        else:
+            self.use_harmony = (self.model_config.model_type == "gpt_oss")

         @asynccontextmanager
         async def lifespan(app: FastAPI):
@@ -166,6 +184,20 @@ class OpenAIServer:
             return JSONResponse(content=error_response.model_dump(),
                                 status_code=error_response.code)

+    def _create_invalid_response_id_error(self, response_id: str) -> Response:
+        return self.create_error_response(
+            err_type="InvalidRequestError",
+            message=(f"Invalid 'response_id': '{response_id}'. "
+                     "Expected an ID that begins with 'resp'."),
+        )
+
+    def _create_response_id_not_found_error(self, response_id: str) -> Response:
+        return self.create_error_response(
+            err_type="InvalidRequestError",
+            message=f"Response with id '{response_id}' not found.",
+            status_code=HTTPStatus.NOT_FOUND,
+        )
+
     def register_routes(self):
         self.app.add_api_route("/health", self.health, methods=["GET"])
         self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
@@ -182,6 +214,9 @@ class OpenAIServer:
         self.app.add_api_route("/v1/chat/completions",
                                self.openai_chat if not self.use_harmony else self.chat_harmony,
                                methods=["POST"])
+        self.app.add_api_route("/v1/responses",
+                               self.openai_responses,
+                               methods=["POST"])
         if self.llm.args.return_perf_metrics:
             # register /prometheus/metrics
             self.mount_metrics()
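Not part of the diff: a small client-side sketch of how the newly registered /v1/responses route and the previous_response_id bookkeeping backed by ConversationHistoryStore might be exercised. The port and model name are assumptions, and chaining only works while storage is enabled, that is, while TRTLLM_RESPONSES_API_DISABLE_STORE is unset.

    # Hypothetical sketch of chaining two Responses API calls through
    # previous_response_id; endpoint, port, and model name are assumptions.
    import requests

    BASE = "http://localhost:8000/v1"

    first = requests.post(f"{BASE}/responses",
                          json={
                              "model": "gpt-oss-20b",
                              "input": "Name one advantage of paged KV cache."
                          }).json()

    # The id carries the "resp_" prefix that openai_responses() validates
    # before looking the previous turn up in the ConversationHistoryStore.
    follow_up = requests.post(f"{BASE}/responses",
                              json={
                                  "model": "gpt-oss-20b",
                                  "input": "Now give a concrete example.",
                                  "previous_response_id": first["id"],
                              }).json()
    print(follow_up["output"])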
@@ -681,11 +716,35 @@ class OpenAIServer:
         Chat Completion API with harmony format support.
         Supports both streaming and non-streaming modes.
         """

+        async def create_harmony_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                chat_response = promise.outputs[0]._postprocess_result
+            else:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                chat_response = post_processor(promise, args)
+
+            return chat_response
+
+        async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
+            if not self.postproc_worker_enabled:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                # await self._extract_metrics(res)
+                for pp_res in pp_results:
+                    yield pp_res
+
+            yield "data: [DONE]\n\n"
+
         try:
             # Initialize HarmonyAdapter
             # NOTE: WAR for Disagg failure, may affect perf if no warmup
             if not self.harmony_adapter:
-                self.harmony_adapter = HarmonyAdapter()
+                self.harmony_adapter = get_harmony_adapter()
             # Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
             tools_dict = None
             if request.tools:
@@ -720,27 +779,37 @@ class OpenAIServer:
                 vocab_size=self.tokenizer.tokenizer.vocab_size)
             sampling_params.detokenize = False  # Harmony adapter handles detokenization

+            postproc_args = ChatCompletionPostprocArgs.from_request(request)
+            postproc_params = PostprocParams(
+                post_processor=chat_harmony_streaming_post_processor
+                if request.stream else chat_harmony_post_processor,
+                postproc_args=postproc_args,
+            )
+
             # Generate
             promise = self.llm.generate_async(
                 inputs=harmony_tokens,
                 sampling_params=sampling_params,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
                 streaming=bool(request.stream),
                 lora_request=request.lora_request,
             )
+            postproc_args.request_id = promise.request_id
+
+            if not self.postproc_worker_enabled:
+                postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)

             # Disconnect cancellation
             asyncio.create_task(self.await_disconnected(raw_request, promise))

             # Handle streaming
             if request.stream:
                 return StreamingResponse(
-                    handle_streaming_response(
-                        self.harmony_adapter, promise,
-                        str(promise.request_id), request,
-                    ),
+                    content=create_streaming_generator(promise, postproc_params),
                     media_type="text/event-stream"
                 )
             else:
-                response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
+                response = await create_harmony_response(promise, postproc_params)
                 return JSONResponse(response.model_dump())

         except Exception as e:
@@ -748,6 +817,80 @@ class OpenAIServer:
             logger.debug("Error details: %s", traceback.format_exc())
             return self.create_error_response(message=str(e), err_type="internal_error")

+    async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
+        async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
+            async for event_data in responses_api_process_streaming_events(
+                request=request,
+                sampling_params=sampling_params,
+                generator=generator,
+                harmony_adapter=self.harmony_adapter,
+                model_name=self.model,
+                conversation_store=self.conversation_store,
+                enable_store=self.enable_store
+            ):
+                yield event_data
+
+        try:
+            if not self.use_harmony:
+                raise NotImplementedError("Responses API only supports harmony format for now")
+
+            # Initialize HarmonyAdapter
+            # NOTE: WAR for Disagg failure, may affect perf if no warmup
+            if not self.harmony_adapter:
+                self.harmony_adapter = HarmonyAdapter()
+
+            if request.background:
+                logger.warning("Request.background is not supported yet, will fallback to foreground processing.")
+
+            # Get prev response
+            prev_response = None
+            if self.enable_store:
+                prev_response_id = request.previous_response_id
+                if prev_response_id is not None:
+                    if not prev_response_id.startswith("resp_"):
+                        return self._create_invalid_response_id_error(prev_response_id)
+
+                    prev_response = await self.conversation_store.load_response(prev_response_id)
+                    if prev_response is None:
+                        logger.debug(f"response_id {prev_response_id} not found")
+                        return self._create_response_id_not_found_error(prev_response_id)
+
+            input_tokens, sampling_params = await responses_api_request_preprocess(
+                request, prev_response, self.harmony_adapter, self.conversation_store, self.enable_store)
+
+            promise = self.llm.generate_async(
+                inputs=input_tokens,
+                sampling_params=sampling_params,
+                streaming=request.stream,
+            )
+
+            asyncio.create_task(self.await_disconnected(raw_request, promise))
+
+            if request.stream:
+                return StreamingResponse(
+                    create_stream_response(promise, request, sampling_params),
+                    media_type="text/event-stream"
+                )
+            else:
+                return await responses_api_create_response(
+                    generator=promise,
+                    request=request,
+                    sampling_params=sampling_params,
+                    model_name=self.model,
+                    conversation_store=self.conversation_store,
+                    generation_result=None,
+                    enable_store=self.enable_store)
+        except CppExecutorError:
+            logger.error(traceback.format_exc())
+            # If internal executor error is raised, shutdown the server
+            signal.raise_signal(signal.SIGINT)
+        except Exception as e:
+            logger.error(traceback.format_exc())
+            return self.create_error_response(str(e))
+
+        return JSONResponse(content={"detail": "None"})
+
+
     async def __call__(self, host, port):
         # Store the binding address for server registration
         self.binding_addr = f"http://{host}:{port}"
@@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
                                        ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
 # yapf: disable
+from .harmony_adapter import (handle_non_streaming_response,
+                              handle_streaming_response)
 from .openai_protocol import (ChatCompletionLogProbs,
                               ChatCompletionLogProbsContent,
                               ChatCompletionNamedToolChoiceParam,
@@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
                               FunctionCall, StreamOptions, ToolCall, UsageInfo,
                               to_disaggregated_params)

-# yapf: enale
+# yapf: enable


 @dataclass(kw_only=True)
 class ChatPostprocArgs(PostprocArgs):
@@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs):
     )


-def create_logprobs(token_ids: List[int],
-                    tokenizer: TransformersTokenizer,
+def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
                     logprobs: List[float]) -> ChatCompletionLogProbs:
     assert len(token_ids) == len(logprobs), \
         "token_ids and logprobs have different lengths"
@@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
     return chat_logprobs


-def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
+def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
+                           streaming: bool) -> Tuple[bool, str, str]:
     reasoning_parser = None
     if args.reasoning_parser is not None:
         if output_index not in args.reasoning_parser_dict:
-            args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
-                args.reasoning_parser)
+            args.reasoning_parser_dict[
+                output_index] = ReasoningParserFactory.create_reasoning_parser(
+                    args.reasoning_parser)
         reasoning_parser = args.reasoning_parser_dict[output_index]

     in_reasoning = False
@@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,


 @nvtx_range_debug("chat_stream_post_processor")
-def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
+def chat_stream_post_processor(rsp: GenerationResultBase,
+                               args: ChatPostprocArgs) -> List[str]:

     def yield_first_chat(num_tokens: int,
                          idx: int,
@@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
     include_continuous_usage = False
     if args.first_iteration:
         for i in range(args.num_choices):
-            res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
+            res.append(
+                f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
+            )
             if args.echo and args.last_message_content:
-                res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
+                res.append(
+                    f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
+                )
         args.first_iteration = False

     for output in rsp.outputs:
@@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
             delta_message = DeltaMessage(
                 content=delta_text, reasoning_content=reasoning_delta_text)

-            choice = ChatCompletionResponseStreamChoice(index=i,
-                                                        delta=delta_message,
-                                                        finish_reason=None,
-                                                        avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
+            choice = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=delta_message,
+                finish_reason=None,
+                avg_decoded_tokens_per_iter=getattr(rsp,
+                                                    'avg_decoded_tokens_per_iter',
+                                                    None))
             if args.return_logprobs:
                 logprobs = output.logprobs_diff
                 token_ids = output.token_ids_diff
-                choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
+                choice.logprobs = create_logprobs(token_ids, args.tokenizer,
+                                                  logprobs)
             if output.finish_reason is not None:
                 choice.finish_reason = output.finish_reason
                 choice.stop_reason = output.stop_reason
@@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
             res.append(f"data: {data}\n\n")

     if include_usage and rsp._done:
-        completion_tokens = sum(output.length
-                                for output in rsp.outputs)
+        completion_tokens = sum(output.length for output in rsp.outputs)
         final_usage = UsageInfo(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
         )

-        final_usage_chunk = ChatCompletionStreamResponse(
-            choices=[], model=args.model, usage=final_usage)
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=args.model,
+                                                         usage=final_usage)
         final_usage_data = final_usage_chunk.model_dump_json()
         res.append(f"data: {final_usage_data}\n\n")
     return res


 @nvtx_range_debug("chat_response_post_processor")
-def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
+def chat_response_post_processor(
+        rsp: GenerationResultBase,
+        args: ChatPostprocArgs) -> ChatCompletionResponse:
     choices: List[ChatCompletionResponseChoice] = []
     role = args.role
     for output in rsp.outputs:
         _, text, reasoning_text = apply_reasoning_parser(
             args, output.index, output.text, False)

-        if args.tool_choice and isinstance(
-                args.tool_choice,
-                ChatCompletionNamedToolChoiceParam):
+        if args.tool_choice and isinstance(args.tool_choice,
+                                           ChatCompletionNamedToolChoiceParam):
             message = ChatMessage(
                 role=role,
                 content="",
                 tool_calls=[
                     ToolCall(function=FunctionCall(
-                        name=args.tool_choice.function.name,
-                        arguments=text))
+                        name=args.tool_choice.function.name, arguments=text))
                 ])
         else:
             if text is None:
                 text = ""
-            message = ChatMessage(
-                role=role, content=text, reasoning_content=reasoning_text)
-        disaggregated_params = to_disaggregated_params(output.disaggregated_params)
+            message = ChatMessage(role=role,
+                                  content=text,
+                                  reasoning_content=reasoning_text)
+        disaggregated_params = to_disaggregated_params(
+            output.disaggregated_params)
         choice = ChatCompletionResponseChoice(
             index=output.index,
             message=message,
             finish_reason=output.finish_reason,
             stop_reason=output.stop_reason,
             disaggregated_params=disaggregated_params,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )

         if args.return_logprobs:
-            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
+            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
+                                              output.logprobs)
         choices.append(choice)

     if args.echo and args.last_message_content:
@@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
             choice.message.content = full_message

     num_prompt_tokens = args.num_prompt_tokens
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in rsp.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs):


 @nvtx_range_debug("completion_stream_post_processor")
-def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
+def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
+                                     args: CompletionPostprocArgs) -> List[str]:
     res: List[str] = []
     prompt_tokens = args.num_prompt_tokens
     if stream_option := args.stream_options:
@@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
             index=args.prompt_idx * args.num_choices + output.index,
             text=delta_text if args.detokenize else "",
             token_ids=None if args.detokenize else output.token_ids_diff,
-            finish_reason = output.finish_reason,
-            stop_reason = output.stop_reason,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            finish_reason=output.finish_reason,
+            stop_reason=output.stop_reason,
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )
         chunk = CompletionStreamResponse(model=args.model, choices=[choice])
         if include_continuous_usage:
@@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
         res.append(f"data: {data}\n\n")

     if include_usage and rsp._done:
-        completion_tokens = sum(output.length
-                                for output in rsp.outputs)
+        completion_tokens = sum(output.length for output in rsp.outputs)
         final_usage = UsageInfo(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
         )

-        final_usage_chunk = ChatCompletionStreamResponse(
-            choices=[], model=args.model, usage=final_usage)
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=args.model,
+                                                         usage=final_usage)
         final_usage_data = final_usage_chunk.model_dump_json()
         res.append(f"data: {final_usage_data}\n\n")
     args.first_iteration = False
@@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:


 @nvtx_range_debug("completion_response_post_processor")
-def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
+def completion_response_post_processor(
+        rsp: GenerationResult,
+        args: CompletionPostprocArgs) -> CompletionResponse:
     prompt_tokens = args.num_prompt_tokens
     completion_tokens = 0
     choices = []
@@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
         text = output.text
         if args.echo:
             text = args.prompt + text
-        disaggregated_params = to_disaggregated_params(output.disaggregated_params)
+        disaggregated_params = to_disaggregated_params(
+            output.disaggregated_params)
         choice = CompletionResponseChoice(
             text=text if args.detokenize else "",
             token_ids=None if args.detokenize else output.token_ids,
             index=args.prompt_idx * args.num_choices + output.index,
             disaggregated_params=disaggregated_params,
-            context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
+            context_logits=None
+            if rsp.context_logits is None else rsp.context_logits.tolist(),
             stop_reason=output.stop_reason,
             finish_reason=output.finish_reason,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )

         completion_tokens += output.length
         choices.append(choice)

     usage = UsageInfo(prompt_tokens=prompt_tokens,
                       completion_tokens=completion_tokens,
                       total_tokens=completion_tokens + prompt_tokens)
-    response = CompletionResponse(choices=choices, model=args.model, usage=usage)
+    response = CompletionResponse(choices=choices,
+                                  model=args.model,
+                                  usage=usage)
+    return response
+
+
+@dataclass(kw_only=True)
+class ChatCompletionPostprocArgs(PostprocArgs):
+    model: str
+    tools: Optional[List[ChatCompletionToolsParam]]
+    tool_choice: Optional[Union[Literal["none", "auto"],
+                                ChatCompletionNamedToolChoiceParam]]
+    request_id: Optional[int] = None
+
+    @classmethod
+    def from_request(cls, request: ChatCompletionRequest):
+        return cls(
+            model=request.model,
+            tools=request.tools,
+            tool_choice=request.tool_choice,
+        )
+
+
+@nvtx_range_debug("chat_harmony_post_processor")
+def chat_harmony_post_processor(
+        rsp: GenerationResult,
+        args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
+    response = handle_non_streaming_response(
+        tools=args.tools,
+        tool_choice=args.tool_choice,
+        outputs=rsp.outputs,
+        model=args.model,
+        num_prompt_tokens=args.num_prompt_tokens,
+    )
+    return response
+
+
+@nvtx_range_debug("chat_harmony_streaming_post_processor")
+def chat_harmony_streaming_post_processor(
+        rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
+    response = handle_streaming_response(
+        tools=args.tools,
+        tool_choice=args.tool_choice,
+        outputs=rsp.outputs,
+        model=args.model,
+        request_id=args.request_id,
+        done=rsp._done,
+        num_prompt_tokens=args.num_prompt_tokens,
+    )
     return response
|
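# --- Editor's illustrative sketch (not part of this commit) -----------------
# The harmony post-processors above are plain callables over a GenerationResult
# plus per-request args, which is what lets them run inside the detached
# postprocess worker processes this PR enables. A minimal wiring sketch; the
# module path and the surrounding serving loop are assumptions, the diff only
# defines the dataclass and the two functions.
from tensorrt_llm.serve.openai_protocol import ChatCompletionRequest
from tensorrt_llm.serve.postprocess_handlers import (  # path assumed
    ChatCompletionPostprocArgs, chat_harmony_post_processor)


def postprocess_chat_harmony(request: ChatCompletionRequest, generation_result,
                             request_id: int):
    """Hypothetical helper: one non-streaming harmony postprocess pass."""
    args = ChatCompletionPostprocArgs.from_request(request)
    args.request_id = request_id  # filled in per request by the server
    # Delegates to handle_non_streaming_response() from the harmony adapter
    # and returns a ChatCompletionResponse ready to serialize.
    return chat_harmony_post_processor(generation_result, args)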
848 tensorrt_llm/serve/responses_utils.py Normal file
@ -0,0 +1,848 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
import os
import time
import uuid
from collections.abc import AsyncGenerator
from copy import copy
from typing import Literal, Optional, OrderedDict, Union

# yapf: disable
from openai.types.responses import (ResponseCompletedEvent,
                                    ResponseContentPartAddedEvent,
                                    ResponseContentPartDoneEvent,
                                    ResponseCreatedEvent,
                                    ResponseFunctionToolCall,
                                    ResponseInProgressEvent, ResponseOutputItem,
                                    ResponseOutputItemAddedEvent,
                                    ResponseOutputItemDoneEvent,
                                    ResponseOutputMessage, ResponseOutputText,
                                    ResponseReasoningItem,
                                    ResponseReasoningTextDeltaEvent,
                                    ResponseReasoningTextDoneEvent,
                                    ResponseTextDeltaEvent,
                                    ResponseTextDoneEvent)
# yapf: enable
from openai.types.responses.response_function_web_search import (
    ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
from openai.types.responses.response_reasoning_item import Content
from openai.types.responses.tool import Tool
from openai_harmony import (Author, Conversation, DeveloperContent,
                            HarmonyEncodingName, Message, ReasoningEffort, Role,
                            StreamState, SystemContent, TextContent,
                            ToolDescription, load_harmony_encoding)

from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel,
                                                ResponseInputOutputItem,
                                                ResponsesRequest,
                                                ResponsesResponse)

from .harmony_adapter import HarmonyAdapter

REASONING_EFFORT = {
    "high": ReasoningEffort.HIGH,
    "medium": ReasoningEffort.MEDIUM,
    "low": ReasoningEffort.LOW,
}

ENABLE_RESPONSES_DEBUG_MSG = False


def responses_debug_log(msg):
    if ENABLE_RESPONSES_DEBUG_MSG:
        logger.debug(msg)


_harmony_encoding = None


def random_uuid():
    return str(uuid.uuid4().hex)


def get_encoding():
    global _harmony_encoding
    if _harmony_encoding is None:
        _harmony_encoding = load_harmony_encoding(
            HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def decode_tokens(tokens):
    return get_encoding().decode(tokens)
def parse_response_input(
    input_msg: ResponseInputOutputItem,
    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
) -> Message:
    if not isinstance(input_msg, dict):
        input_msg = input_msg.model_dump()

    responses_debug_log(f"------- Parsing input -----------")
    responses_debug_log(input_msg)
    responses_debug_log("")

    if "type" not in input_msg or input_msg["type"] == "message":
        role = input_msg["role"]
        content = input_msg["content"]
        if role == "system":
            # User is trying to set a system message. Change it to:
            # <|start|>developer<|message|># Instructions
            # {instructions}<|end|>
            role = "developer"
            text_prefix = "Instructions:\n"
        else:
            text_prefix = ""
        if isinstance(content, str):
            msg = Message.from_role_and_content(role, text_prefix + content)
        elif isinstance(content, list):
            contents = [
                TextContent(text=text_prefix + c["text"]) for c in content
            ]
            msg = Message.from_role_and_contents(role, contents)
        else:
            logger.warning("Responses API: Invalid input message type")
            msg = None
    elif input_msg["type"] == "function_call_output":
        call_id = input_msg["call_id"]
        call_response: Optional[ResponseFunctionToolCall] = None
        for prev_response in reversed(prev_responses):
            if isinstance(prev_response, ResponseFunctionToolCall
                          ) and prev_response.call_id == call_id:
                call_response = prev_response
                break
        if call_response is None:
            raise ValueError(f"No call message found for {call_id}")
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{call_response.name}"),
            input_msg["output"])
    elif input_msg["type"] == "reasoning":
        content = input_msg["content"]
        assert len(content) == 1
        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
    elif input_msg["type"] == "function_call":
        msg = Message.from_role_and_content(Role.ASSISTANT,
                                            input_msg["arguments"])
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(f"functions.{input_msg['name']}")
        msg = msg.with_content_type("json")
    else:
        raise ValueError(f"Unknown input type: {input_msg['type']}")
    return msg
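# --- Editor's illustrative sketch (not part of this file) -------------------
# parse_response_input() accepts the dict-shaped items a Responses API client
# sends as `input`. Two of the shapes handled above, fed through the parser;
# the call_id/name values are made up for illustration.
from openai.types.responses import ResponseFunctionToolCall

example_call = ResponseFunctionToolCall(type="function_call",
                                        id="fc_demo",
                                        call_id="call_demo",
                                        name="get_current_weather",
                                        arguments='{"location": "SF"}')

example_inputs = [
    # plain chat message (no "type" key)
    {"role": "user", "content": "What is the weather like in SF?"},
    # tool result that refers back to example_call via call_id
    {"type": "function_call_output", "call_id": "call_demo",
     "output": '{"sunny": true}'},
]

harmony_msgs = [
    parse_response_input(item, prev_responses=[example_call])
    for item in example_inputs
]
# Each entry is an openai_harmony Message: the user text, then a tool message
# authored as functions.get_current_weather on the tool role.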
class ConversationHistoryStore:

    def __init__(self, resp_capacity: int = 16, max_conversations=32):
        self.response_capacity = resp_capacity
        self.conversation_capacity = resp_capacity * 4
        self.max_conversations = max_conversations

        self.responses_lock = asyncio.Lock()
        self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict()

        self.conversations_lock = asyncio.Lock()
        self.conversations: OrderedDict[str, list[Message]] = OrderedDict()
        self.response_to_conversation: dict[str, str] = {}
        self.conversation_to_response: dict[str, str] = {}

    async def load_response(self, resp_id: str) -> ResponsesResponse:
        responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}")
        async with self.responses_lock:
            return self.responses.get(resp_id)

    async def store_response(self,
                             resp: ResponsesResponse,
                             resp_msgs: Optional[list[Message]] = [],
                             prev_resp_id: Optional[str] = None) -> None:
        resp_id = resp.id
        responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}")
        async with self.responses_lock:
            self.responses[resp_id] = resp
            if len(self.responses) > self.response_capacity:
                self._pop_response()

        async with self.conversations_lock:
            conversation_id: str
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
            elif prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
                while len(self.conversations[conversation_id]
                          ) > self.conversation_capacity:
                    self._pop_conversation(resp_id)
            else:
                conversation_id = random_uuid()
                self.conversations[conversation_id] = resp_msgs

            responses_debug_log(
                f" * storing at conversation id: {conversation_id}")

            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            self._update_visited_conversation(conversation_id)

    async def store_messages(self, resp_id: str, msgs: list[Message],
                             prev_resp_id: Optional[str]):
        responses_debug_log(f"ConversationHistoryStore storing msg:")
        for msg in msgs:
            responses_debug_log(f"  -> {msg.to_json()}")

        async with self.conversations_lock:
            conversation_id: str
            if prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
            else:
                conversation_id = random_uuid()

            responses_debug_log(
                f" * storing at conversation: {conversation_id}")
            self.conversations[conversation_id] = msgs
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)

            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            self._update_visited_conversation(conversation_id)

    async def append_messages(self, resp_id: str, msgs: list[Message]):
        responses_debug_log(f"ConversationHistoryStore appending msgs:")
        for msg in msgs:
            responses_debug_log(f"  -> {msg.to_json()}")

        async with self.conversations_lock:
            assert resp_id in self.response_to_conversation
            conversation_id = self.response_to_conversation[resp_id]

            responses_debug_log(
                f" * appending at conversation: {conversation_id}")
            self.conversations[conversation_id].extend(msgs)
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)
            self._update_visited_conversation(conversation_id)

    async def get_conversation_history(self, resp_id: str) -> list[Message]:
        responses_debug_log(f"ConversationHistoryStore getting prev_msgs:")
        responses_debug_log(f"  -> prev_resp_id: {resp_id}")
        async with self.conversations_lock:
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self._update_visited_conversation(conversation_id)
                return self.conversations.get(conversation_id, [])

        return []

    def _update_visited_conversation(self, conversation_id) -> None:
        if conversation_id not in self.conversations:
            return

        self.conversations.move_to_end(conversation_id)
        if len(self.conversations) > self.max_conversations:
            removed_id, _ = self.conversations.popitem(last=False)
            responses_debug_log(
                f"ConversationHistoryStore Removing conversation {removed_id}")
            removed_resp_id = self.conversation_to_response[removed_id]
            # The responses may have been removed due to response capacity
            if removed_resp_id in self.response_to_conversation:
                self.response_to_conversation.pop(removed_resp_id)
            self.conversation_to_response.pop(removed_id)

    def _pop_conversation(self, resp_id) -> None:
        conversation_id = self.response_to_conversation.get(resp_id, None)
        if conversation_id is None:
            return

        conversation = self.conversations[conversation_id]
        first_conversation_range = []
        for i, msg in enumerate(conversation):
            if msg.author.role == Role.USER:
                first_conversation_range.append(i)
            elif msg.channel == "final":
                first_conversation_range.append(i)
                break
        del conversation[
            first_conversation_range[0]:first_conversation_range[1] + 1]

    def _pop_response(self) -> None:
        responses_debug_log(f"responses type: {type(self.responses)}")
        resp_id, _ = self.responses.popitem(last=False)
        if resp_id in self.response_to_conversation:
            self.response_to_conversation.pop(resp_id)
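# --- Editor's illustrative sketch (not part of this file) -------------------
# ConversationHistoryStore keys everything by response id: a response's
# harmony messages are appended to a conversation, and previous_response_id
# links a follow-up request to that conversation. Minimal standalone usage
# (ids and messages are made up; in the server this is driven by
# request_preprocess()/create_response() further below).
import asyncio

from openai_harmony import Message, Role


async def _history_demo():
    store = ConversationHistoryStore(resp_capacity=4)
    first_turn = [
        Message.from_role_and_content(Role.USER, "What is 1+1?"),
        Message.from_role_and_content(Role.ASSISTANT,
                                      "2").with_channel("final"),
    ]
    await store.store_messages("resp_1", first_turn, prev_resp_id=None)

    # A follow-up request carrying previous_response_id="resp_1" can now
    # prepend the stored history to its harmony prompt.
    return await store.get_conversation_history("resp_1")


# asyncio.run(_history_demo())  # returns the two messages stored above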
def get_system_message(
    model_identity: Optional[str] = None,
    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
    start_date: Optional[str] = None,
    browser_description: Optional[str] = None,
    python_description: Optional[str] = None,
) -> Message:
    sys_msg_content = SystemContent.new()
    if model_identity is not None:
        sys_msg_content = sys_msg_content.with_model_identity(model_identity)
    if reasoning_effort is not None:
        sys_msg_content = sys_msg_content.with_reasoning_effort(
            REASONING_EFFORT[reasoning_effort])
    if start_date:
        sys_msg_content = sys_msg_content.with_conversation_start_date(
            start_date)
    if browser_description is not None:
        sys_msg_content = sys_msg_content.with_tools(browser_description)
    if python_description is not None:
        sys_msg_content = sys_msg_content.with_tools(python_description)
    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
    return sys_msg


def get_developer_message(instructions: Optional[str] = None,
                          tools: Optional[list[Tool]] = None) -> Message:
    dev_msg_content = DeveloperContent.new()
    if instructions is not None:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools = []
        for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                # These are built-in tools that are added to the system message.
                pass
            elif tool.type == "function":
                function_tools.append(tool)
            else:
                raise ValueError(f"tool type {tool.type} not supported")
        if function_tools:
            function_tool_descriptions = [
                ToolDescription.new(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                ) for tool in function_tools
            ]
            dev_msg_content = dev_msg_content.with_function_tools(
                function_tool_descriptions)
    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
    return dev_msg


def get_user_message(content: str) -> Message:
    return Message.from_role_and_content(Role.USER, content)
def construct_harmony_messages(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: list[Message] = [],
) -> list[Message]:
    """Construct messages from request input, includes conversation history messages if exists."""
    messages: list[Message] = []
    if prev_response is None:
        # New conversation.
        reasoning_effort = (request.reasoning.effort
                            if request.reasoning else None)
        sys_msg = get_system_message(reasoning_effort=reasoning_effort, )
        messages.append(sys_msg)
        dev_msg = get_developer_message(request.instructions, request.tools)
        messages.append(dev_msg)
    else:
        messages.extend(prev_msgs)
    # Append the new input.
    # Responses API supports simple text inputs without chat format.
    if isinstance(request.input, str):
        messages.append(get_user_message(request.input))
    else:
        if prev_response is not None:
            prev_outputs = copy(prev_response.output)
        else:
            prev_outputs = []
        for input_msg in request.input:
            msg = parse_response_input(input_msg, prev_outputs)
            if msg is not None:
                messages.append(msg)
            # User passes in a tool call request and its output. We need
            # to add the tool call request to prev_outputs so that the
            # parse_response_input can find the tool call request when
            # parsing the tool call output.
            if isinstance(input_msg, ResponseFunctionToolCall):
                prev_outputs.append(input_msg)
    return messages


def render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    responses_debug_log("Rendering conversation:")
    responses_debug_log(conversation.to_json())
    token_ids = get_encoding().render_conversation_for_completion(
        conversation, Role.ASSISTANT)
    return token_ids


def parse_output_tokens(tokens: list[int]) -> list[Message]:
    return get_encoding().parse_messages_from_completion_tokens(
        tokens, role=Role.ASSISTANT)
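# --- Editor's illustrative sketch (not part of this file) -------------------
# End-to-end shape of the prompt helpers above: build system/developer/user
# harmony messages, render them to prefill token ids, and (after generation)
# parse the completion tokens back into Messages. The generation step itself
# is elided; `completion_token_ids` is a placeholder.
prompt_messages = [
    get_system_message(reasoning_effort="low"),
    get_developer_message(instructions="Answer concisely."),
    get_user_message("Which one is larger as numeric, 9.9 or 9.11?"),
]
prompt_token_ids = render_for_completion(prompt_messages)
print(decode_tokens(prompt_token_ids))  # harmony-formatted prefill text

# completion_token_ids = <tokens returned by the engine for this prompt>
# output_messages = parse_output_tokens(completion_token_ids)
# output_items = [i for m in output_messages for i in parse_output_message(m)]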
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
    """
    Parse a Harmony message into a list of output response items.
    """
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    output_items: list[ResponseOutputItem] = []
    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        if len(message.content) != 1:
            raise ValueError("Invalid number of contents in browser message")
        content = message.content[0]
        browser_call = json.loads(content.text)
        # TODO: translate to url properly!
        if recipient == "browser.search":
            action = ActionSearch(
                query=f"cursor:{browser_call.get('query', '')}", type="search")
        elif recipient == "browser.open":
            action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}",
                                    type="open_page")
        elif recipient == "browser.find":
            action = ActionFind(pattern=browser_call["pattern"],
                                url=f"cursor:{browser_call.get('url', '')}",
                                type="find")
        else:
            raise ValueError(f"Unknown browser action: {recipient}")
        web_search_item = ResponseFunctionWebSearch(
            id=f"ws_{random_uuid()}",
            action=action,
            status="completed",
            type="web_search_call",
        )
        output_items.append(web_search_item)
    elif message.channel == "analysis":
        for content in message.content:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                summary=[],
                type="reasoning",
                content=[Content(text=content.text, type="reasoning_text")],
                status=None,
            )
            output_items.append(reasoning_item)
    elif message.channel == "commentary":
        if message.recipient is None:
            pass
        elif message.recipient.startswith("functions."):
            function_name = message.recipient.split(".")[-1]
            for content in message.content:
                random_id = random_uuid()
                response_item = ResponseFunctionToolCall(
                    arguments=content.text,
                    call_id=f"call_{random_id}",
                    type="function_call",
                    name=function_name,
                    id=f"fc_{random_id}",
                )
                output_items.append(response_item)
        elif message.recipient.startswith(
                "python") or message.recipient.startswith("browser"):
            for content in message.content:
                reasoning_item = ResponseReasoningItem(
                    id=f"rs_{random_uuid()}",
                    summary=[],
                    type="reasoning",
                    content=[Content(text=content.text, type="reasoning_text")],
                    status=None,
                )
                output_items.append(reasoning_item)
        else:
            raise ValueError(f"Unknown recipient: {message.recipient}")
    elif message.channel == "final":
        contents = []
        for content in message.content:
            output_text = ResponseOutputText(
                text=content.text,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            contents.append(output_text)
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=contents,
            role=message.author.role,
            status="completed",
            type="message",
        )
        output_items.append(text_item)
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items


def finish_reason_mapping(finish_reason: str) -> str:
    match finish_reason:
        case 'stop':
            return 'completed'
        case 'length':
            return 'incomplete'
        case 'timeout':
            return 'failed'
        case 'cancelled':
            return 'cancelled'

    raise RuntimeError("Should never reach here!")
async def request_preprocess(request: ResponsesRequest,
                             prev_response: Optional[ResponsesResponse],
                             harmony_adapter: HarmonyAdapter,
                             conversation_store: ConversationHistoryStore,
                             enable_store=False):
    # TODO: fix default_max_tokens
    sampling_params = request.to_sampling_params(
        default_max_tokens=int(16384),
        default_sampling_params={
            "stop_token_ids": harmony_adapter.get_stop_tokens()
        })

    prev_response_id = request.previous_response_id

    # TODO: better way to enable metrics
    if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
        sampling_params.return_perf_metrics = True

    prev_msgs = []
    if enable_store:
        prev_msgs = await conversation_store.get_conversation_history(
            prev_response_id)

        responses_debug_log(f"Prev msgs:")
        for msg in prev_msgs:
            responses_debug_log(f"  -> {msg.to_json()}")

    messages = construct_harmony_messages(request,
                                          prev_response,
                                          prev_msgs=prev_msgs)

    if enable_store and request.store:
        # Remove reasoning messages to save token usage during multi-turn conversation
        msgs_to_store = [msg for msg in messages if msg.channel != "analysis"]
        await conversation_store.store_messages(request.request_id,
                                                msgs_to_store, prev_response_id)

    input_tokens = render_for_completion(messages)

    responses_debug_log("======= Complete Inputs to model =======")
    responses_debug_log(decode_tokens(input_tokens))
    responses_debug_log("========================================")
    return input_tokens, sampling_params
async def create_response(
    generator,
    request: ResponsesRequest,
    sampling_params,
    model_name: str,
    conversation_store: ConversationHistoryStore,
    generation_result: RequestOutput = None,
    enable_store=False,
    create_time: int = None,
) -> ResponsesResponse:

    final_res: Optional[RequestOutput] = None
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    prev_response_id = request.previous_response_id

    if generation_result is not None:
        final_res = generation_result
    else:
        final_res = await generator

    if final_res is None:
        raise RuntimeError("No output generated or provided")

    responses_debug_log("================================================")
    responses_debug_log("RAW MODEL OUTPUT:")
    responses_debug_log(final_res.outputs)
    responses_debug_log("================================================")

    output_messages = parse_output_tokens(final_res.outputs[0].token_ids)

    responses_debug_log(f"output messages: {len(output_messages)}")
    for msg in output_messages:
        responses_debug_log(f"  -> {msg.to_json()}")

    # prepare responses output
    output_content = []
    for msg in output_messages:
        output_content.extend(parse_output_message(msg))

    response = ResponsesResponse.from_request(
        request=request,
        sampling_params=sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=output_content,
        status=finish_reason_mapping(final_res.outputs[0].finish_reason),
    )

    if enable_store and request.store:
        await conversation_store.store_response(resp=response,
                                                resp_msgs=output_messages,
                                                prev_resp_id=prev_response_id)

    responses_debug_log("========== Response ===========")
    responses_debug_log(response)
    responses_debug_log("===============================")
    return response
async def process_streaming_events(
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        generator,
        harmony_adapter: HarmonyAdapter,
        model_name: str,
        conversation_store: ConversationHistoryStore,
        create_time: int = None,
        enable_store=False) -> AsyncGenerator[str, None]:
    sequence_number = 0
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    final_res: Optional[RequestOutput] = None

    def _send_event(event: OpenAIBaseModel):
        nonlocal sequence_number
        # Set sequence_number if the event has this attribute
        if hasattr(event, 'sequence_number'):
            event.sequence_number = sequence_number
            sequence_number += 1
        # Get event type from the event's type field if it exists
        event_type = getattr(event, 'type', 'unknown')
        return (f"event: {event_type}\n"
                f"data: {event.model_dump_json(indent=None)}\n\n")

    current_content_index = 0  # FIXME: this number is never changed
    current_output_index = 0
    current_item_id = ""  # FIXME: this number is never changed
    sent_output_item_added = False

    initial_response = ResponsesResponse.from_request(
        request,
        sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=[],
        status="in_progress",
        usage=None,
    ).model_dump()
    yield _send_event(
        ResponseCreatedEvent(
            type="response.created",
            sequence_number=-1,
            response=initial_response,
        ))
    yield _send_event(
        ResponseInProgressEvent(
            type="response.in_progress",
            sequence_number=-1,
            response=initial_response,
        ))

    tools = [tool.model_dump() for tool in request.tools]
    stream_request_id = f"responses-api-{request.request_id}"
    async for res in generator:
        final_res = res
        output = res.outputs[0]

        messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages(
            stream_request_id, output.token_ids_diff, tools,
            request.tool_choice)
        stream_state = harmony_adapter.get_stream_state(stream_request_id)
        assert stream_state is not None
        parser = stream_state.get_parser()

        if parser.state == StreamState.EXPECT_START:
            current_output_index += 1
            sent_output_item_added = False

            if len(messages) > 0:
                previous_item = messages[-1]
                if previous_item.recipient is not None:
                    # Deal with tool call here
                    pass
                elif previous_item.channel == "analysis":
                    reasoning_item = ResponseReasoningItem(
                        type="reasoning",
                        content=[
                            Content(
                                text=previous_item.content[0].text,
                                type="reasoning_text",
                            ),
                        ],
                        status="completed",
                        id=current_item_id,
                        summary=[],
                    )
                    yield _send_event(
                        ResponseReasoningTextDoneEvent(
                            type="response.reasoning_text.done",
                            item_id=current_item_id,
                            sequence_number=-1,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            text=previous_item.content[0].text,
                        ))
                    yield _send_event(
                        ResponseOutputItemDoneEvent(
                            type="response.output_item.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=reasoning_item,
                        ))
                elif previous_item.channel == "final":
                    text_content = ResponseOutputText(
                        type="output_text",
                        text=previous_item.content[0].text,
                        annotations=[],
                    )
                    yield _send_event(
                        ResponseTextDoneEvent(
                            type="response.output_text.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            text=previous_item.content[0].text,
                            logprobs=[],
                            item_id=current_item_id,
                        ))
                    yield _send_event(
                        ResponseContentPartDoneEvent(
                            type="response.content_part.done",
                            sequence_number=-1,
                            item_id=current_item_id,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            part=text_content,
                        ))
                    yield _send_event(
                        ResponseOutputItemDoneEvent(
                            type="response.output_item.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseOutputMessage(
                                id=current_item_id,
                                type="message",
                                role="assistant",
                                content=[text_content],
                                status="completed",
                            ),
                        ))

        if parser.last_content_delta:
            if (parser.current_channel == "final"
                    and parser.current_recipient is None):
                if not sent_output_item_added:
                    sent_output_item_added = True
                    yield _send_event(
                        ResponseOutputItemAddedEvent(
                            type="response.output_item.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseOutputMessage(
                                id=current_item_id,
                                type="message",
                                role="assistant",
                                content=[],
                                status="in_progress",
                            ),
                        ))
                    yield _send_event(
                        ResponseContentPartAddedEvent(
                            type="response.content_part.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item_id=current_item_id,
                            content_index=current_content_index,
                            part=ResponseOutputText(
                                type="output_text",
                                text="",
                                annotations=[],
                                logprobs=[],
                            ),
                        ))
                yield _send_event(
                    ResponseTextDeltaEvent(
                        type="response.output_text.delta",
                        sequence_number=-1,
                        content_index=current_content_index,
                        output_index=current_output_index,
                        item_id=current_item_id,
                        delta=parser.last_content_delta,
                        # TODO, use logprobs from ctx.last_request_output
                        logprobs=[],
                    ))
            elif (parser.current_channel == "analysis"
                  and parser.current_recipient is None):
                if not sent_output_item_added:
                    sent_output_item_added = True
                    yield _send_event(
                        ResponseOutputItemAddedEvent(
                            type="response.output_item.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseReasoningItem(
                                type="reasoning",
                                id=current_item_id,
                                summary=[],
                                status="in_progress",
                            ),
                        ))
                    yield _send_event(
                        ResponseContentPartAddedEvent(
                            type="response.content_part.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item_id=current_item_id,
                            content_index=current_content_index,
                            part=ResponseOutputText(
                                type="output_text",
                                text="",
                                annotations=[],
                                logprobs=[],
                            ),
                        ))
                yield _send_event(
                    ResponseReasoningTextDeltaEvent(
                        type="response.reasoning_text.delta",
                        item_id=current_item_id,
                        output_index=current_output_index,
                        content_index=current_content_index,
                        delta=parser.last_content_delta,
                        sequence_number=-1,
                    ))

    # TODO(JunyiXu-nv): support built-in tools(python/browser/code interpreter)

    final_response = await create_response(generator, request, sampling_params,
                                           model_name, conversation_store,
                                           final_res, enable_store,
                                           response_creation_time)

    yield _send_event(
        ResponseCompletedEvent(
            type="response.completed",
            sequence_number=-1,
            response=final_response.model_dump(),
        ))
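# --- Editor's illustrative sketch (not part of this file) -------------------
# process_streaming_events() yields Server-Sent-Events-formatted strings via
# _send_event(): an "event:" line carrying the event type, a "data:" line with
# the JSON payload, and a blank separator line. Reproducing the framing with a
# standalone event object (field values are illustrative only).
from openai.types.responses import ResponseTextDeltaEvent

chunk = ResponseTextDeltaEvent(type="response.output_text.delta",
                               sequence_number=3,
                               item_id="",
                               output_index=1,
                               content_index=0,
                               delta="Hello",
                               logprobs=[])
wire = (f"event: {chunk.type}\n"
        f"data: {chunk.model_dump_json(indent=None)}\n\n")
# e.g.
# event: response.output_text.delta
# data: {"content_index":0,"delta":"Hello", ...}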
@ -1513,6 +1513,13 @@ def test_openai_chat_harmony(llm_root, llm_venv):
        str(test_root / "_test_openai_chat_harmony.py")])


def test_openai_responses(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(
        ["-m", "pytest",
         str(test_root / "_test_openai_responses.py")])


def test_openai_prometheus(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(
@ -104,6 +104,7 @@ l0_h100:
  - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
  - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
  - test_e2e.py::test_openai_chat_harmony
  - test_e2e.py::test_openai_responses
  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
  # ------------- AutoDeploy tests ---------------
  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
@ -14,10 +14,18 @@ def model():
    return "gpt_oss/gpt-oss-20b/"


@pytest.fixture(scope="module",
                params=[0, 2],
                ids=["disable_processpool", "enable_processpool"])
def num_postprocess_workers(request):
    return request.param


@pytest.fixture(scope="module")
def server(model: str):
def server(model: str, num_postprocess_workers: int):
    model_path = get_model_path(model)
    with RemoteOpenAIServer(model_path) as remote_server:
    args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
    with RemoteOpenAIServer(model_path, args) as remote_server:
        yield remote_server


@ -147,6 +155,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
    collected_chunks = []
    collected_messages = []
    async for chunk in response:
        # Last streaming response will only contains usage info
        if len(chunk.choices) <= 0:
            continue

        collected_chunks.append(chunk)
        collected_messages.append(chunk.choices[0].delta)

@ -198,6 +210,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
    reasoning_chunks: list[str] = []
    tool_arg_chunks: list[str] = []
    async for chunk in response:
        # Last streaming response will only contains usage info
        if len(chunk.choices) <= 0:
            continue

        delta = chunk.choices[0].delta
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            function = delta.tool_calls[0].function
|||||||
241
tests/unittest/llmapi/apps/_test_openai_responses.py
Normal file
241
tests/unittest/llmapi/apps/_test_openai_responses.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
from openai.types.responses import (ResponseCompletedEvent,
|
||||||
|
ResponseReasoningTextDeltaEvent,
|
||||||
|
ResponseTextDeltaEvent)
|
||||||
|
|
||||||
|
from ..test_llm import get_model_path
|
||||||
|
from .openai_server import RemoteOpenAIServer
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.threadleak(enabled=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
|
||||||
|
def model():
|
||||||
|
return "gpt_oss/gpt-oss-20b/"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def server(model: str):
|
||||||
|
model_path = get_model_path(model)
|
||||||
|
with RemoteOpenAIServer(model_path) as remote_server:
|
||||||
|
yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def client(server: RemoteOpenAIServer):
|
||||||
|
return server.get_async_client()
|
||||||
|
|
||||||
|
|
||||||
|
def check_reponse(response, prefix=""):
|
||||||
|
reasoning_exist, message_exist = False, False
|
||||||
|
for output in response.output:
|
||||||
|
if output.type == "reasoning":
|
||||||
|
reasoning_exist = True
|
||||||
|
elif output.type == "message":
|
||||||
|
message_exist = True
|
||||||
|
|
||||||
|
assert reasoning_exist, f"{prefix}Reasoning content not exists!"
|
||||||
|
assert message_exist, f"{prefix}Message content not exists!"
|
||||||
|
|
||||||
|
|
||||||
|
def check_tool_calling(response, first_resp=True, prefix=""):
|
||||||
|
reasoning_exist, tool_call_exist, message_exist = False, False, False
|
||||||
|
function_call = None
|
||||||
|
for output in response.output:
|
||||||
|
if output.type == "reasoning":
|
||||||
|
reasoning_exist = True
|
||||||
|
elif output.type == "function_call":
|
||||||
|
tool_call_exist = True
|
||||||
|
function_call = output
|
||||||
|
elif output.type == "message":
|
||||||
|
message_exist = True
|
||||||
|
|
||||||
|
if first_resp:
|
||||||
|
assert reasoning_exist and tool_call_exist, f"{prefix}Invalid tool calling 1st response"
|
||||||
|
assert not message_exist, f"{prefix}Invalid tool calling 1st response"
|
||||||
|
|
||||||
|
return function_call
|
||||||
|
else:
|
||||||
|
assert reasoning_exist and message_exist, f"{prefix}Invalid tool calling 2nd response"
|
||||||
|
assert not tool_call_exist, f"{prefix}Invalid tool calling 2nd response"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_reasoning(client: openai.AsyncOpenAI, model: str):
|
||||||
|
response = await client.responses.create(
|
||||||
|
model=model, input="Which one is larger as numeric, 9.9 or 9.11?")
|
||||||
|
|
||||||
|
check_reponse(response, "test_reasoning: ")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
|
||||||
|
for effort in ["low", "medium", "high"]:
|
||||||
|
response = await client.responses.create(
|
||||||
|
model=model,
|
||||||
|
instructions="Use less than 1024 tokens for reasoning",
|
||||||
|
input="Which one is larger as numeric, 9.9 or 9.11?",
|
||||||
|
reasoning={"effort": effort})
|
||||||
|
check_reponse(response, f"test_reasoning_effort_{effort}: ")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_chat(client: openai.AsyncOpenAI, model: str):
|
||||||
|
response = await client.responses.create(model=model,
|
||||||
|
input=[{
|
||||||
|
"role":
|
||||||
|
"developer",
|
||||||
|
"content":
|
||||||
|
"Respond in Chinese."
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello!"
|
||||||
|
}, {
|
||||||
|
"role":
|
||||||
|
"assistant",
|
||||||
|
"content":
|
||||||
|
"Hello! How can I help you?"
|
||||||
|
}, {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Tell me a joke."
|
||||||
|
}])
|
||||||
|
check_reponse(response, "test_chat: ")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str):
|
||||||
|
response = await client.responses.create(model=model,
|
||||||
|
input="What is the answer of 1+1?")
|
||||||
|
check_reponse(response, "test_multi_turn_chat_1: ")
|
||||||
|
|
||||||
|
response_2 = await client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input="What is the answer of previous question?",
|
||||||
|
previous_response_id=response.id)
|
||||||
|
check_reponse(response_2, "test_multi_turn_chat_2: ")
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_weather(location: str, format: str = "celsius") -> dict:
|
||||||
|
return {"sunny": True, "temperature": 20 if format == "celsius" else 68}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
|
||||||
|
tool_get_current_weather = {
|
||||||
|
"type": "function",
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Gets the current weather in the provided location.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "default: celsius",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
|
||||||
|
response = await client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input=messages,
|
||||||
|
tools=[tool_get_current_weather],
|
||||||
|
)
|
||||||
|
messages.extend(response.output)
|
||||||
|
function_call = check_tool_calling(response, True, "test_tool_calls: ")
|
||||||
|
|
||||||
|
assert function_call.name == "get_current_weather"
|
||||||
|
|
||||||
|
args = json.loads(function_call.arguments)
|
||||||
|
answer = get_current_weather(**args)
|
||||||
|
messages.append({
|
||||||
|
"type": "function_call_output",
|
||||||
|
"call_id": function_call.call_id,
|
||||||
|
"output": json.dumps(answer),
|
||||||
|
})
|
||||||
|
|
||||||
|
response = await client.responses.create(model=model,
|
||||||
|
input=messages,
|
||||||
|
tools=[tool_get_current_weather])
|
||||||
|
|
||||||
|
check_tool_calling(response, False, "test_tool_calls: ")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_streaming(client: openai.AsyncOpenAI, model: str):
|
||||||
|
stream = await client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input="Explain the theory of relativity in brief.",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_deltas, message_deltas = list(), list()
|
||||||
|
async for event in stream:
|
||||||
|
if isinstance(event, ResponseTextDeltaEvent):
|
||||||
|
message_deltas.append(event.delta)
|
||||||
|
elif isinstance(event, ResponseReasoningTextDeltaEvent):
|
||||||
|
reasoning_deltas.append(event.delta)
|
||||||
|
|
||||||
|
full_response = "".join(message_deltas)
|
||||||
|
full_reasoning_response = "".join(reasoning_deltas)
|
||||||
|
assert full_response
|
||||||
|
assert full_reasoning_response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="module")
|
||||||
|
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
|
||||||
|
tool_get_current_weather = {
|
||||||
|
"type": "function",
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Gets the current weather in the provided location.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "default: celsius",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
|
||||||
|
stream = await client.responses.create(
|
||||||
|
model=model,
|
||||||
|
input=messages,
|
||||||
|
tools=[tool_get_current_weather],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = None
|
||||||
|
reasoning_deltas = list()
|
||||||
|
async for event in stream:
|
||||||
|
if isinstance(event, ResponseCompletedEvent):
|
||||||
|
for output in event.response.output:
|
||||||
|
if output.type == "function_call":
|
||||||
|
function_call = output
|
||||||
|
elif isinstance(event, ResponseReasoningTextDeltaEvent):
|
||||||
|
reasoning_deltas.append(event.delta)
|
||||||
|
|
||||||
|
reasoning = "".join(reasoning_deltas)
|
||||||
|
tool_args = json.loads(function_call.arguments)
|
||||||
|
|
||||||
|
assert function_call.name == "get_current_weather", "wrong function calling name"
|
||||||
|
assert tool_args, "tool args not exists!"
|
||||||
|
assert reasoning, "reasoning not exists!"
|
||||||
|
|
||||||
|
get_current_weather(**tool_args)
|
||||||