Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Cherry-pick Responses API and multiple postprocess workers support for chat harmony (#7600)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Co-authored-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Co-authored-by: Tao Li @ NVIDIA <tali@nvidia.com>
Parent: d60dad6b9d
Commit: ac0df0a393
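For context: the hunks below add a new /v1/responses route to trtllm-serve (see the openai_server.py changes). As an illustration only, not part of the commit, the endpoint can be exercised with a plain HTTP request. The sketch below assumes a trtllm-serve instance hosting a gpt-oss model at http://localhost:8000; the model name is a placeholder, and the field names follow the ResponsesRequest schema introduced in this diff.

# Minimal sketch (assumptions: server at localhost:8000, placeholder model name).
import requests

payload = {
    "model": "gpt-oss-20b",                  # placeholder; use the model name the server reports
    "instructions": "Answer concisely.",     # maps to ResponsesRequest.instructions
    "input": "What is the capital of France?",
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/responses", json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
# ResponsesResponse carries an "output" list of reasoning/message/function_call items.
for item in body.get("output", []):
    print(item.get("type"), item)
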
@@ -6,7 +6,7 @@ import re
import time
import traceback
import uuid
from typing import Any, AsyncGenerator, Literal
from typing import Any, List, Literal

from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, HarmonyError, Message,
@@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)

from tensorrt_llm.llmapi import RequestOutput
from tensorrt_llm.logger import logger

# yapf: disable
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)

@@ -57,7 +57,8 @@ class HarmonyStreamState:
# Normal case: filter based on available tools
self.should_filter_tools = True
self.available_tools = {
tool.get("function", {}).get("name", "")
tool.get("function", {}).get("name", "") if tool.get(
"name", None) is None else tool.get("name")
for tool in available_tools
}
self.available_tools.discard("")
@@ -78,6 +79,9 @@ class HarmonyStreamState:

logger.debug("Created HarmonyStreamState for request %s", request_id)

def get_parser(self) -> StreamableParser:
return self.parser

def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
"""
Process a batch of tokens while maintaining parsing state.
@@ -125,6 +129,42 @@ class HarmonyStreamState:

return deltas

def process_token_batch_to_messages(self,
tokens: list[int]) -> list[Message]:
"""
Process a batch of tokens while maintaining parsing state.
Returns OpenAI Messages for Responses API
"""
self.tokens_processed += len(tokens)

for token in tokens:
# Store previous state for transition detection
prev_channel = self.parser.current_channel
prev_recipient = self.parser.current_recipient

# Process the token
self.parser.process(token)

# Detect channel/recipient transitions AFTER processing each token
channel_changed = prev_channel != self.parser.current_channel
recipient_changed = prev_recipient != self.parser.current_recipient

if channel_changed or recipient_changed:
# Mark any active tool calls as completed if we're leaving a tool call
if prev_channel == "commentary" and prev_recipient and "functions." in str(
prev_recipient):
func_name = str(prev_recipient).split("functions.")[-1]
for tool_id, tool_info in self.tool_calls.items():
if tool_info["name"] == func_name and tool_info.get(
"active", True):
tool_info["active"] = False

# Reset channel state for new channel
self.channel_started = False
self.current_channel_state = None

return self.parser.messages

def _create_closing_token_delta(self) -> dict[str, Any] | None:
"""Create closing token delta for channel transition."""
if not self.current_channel_state or not self.channel_started:
@@ -317,6 +357,9 @@ class HarmonyAdapter:
|
||||
"<|constrain|>": 200009,
|
||||
}
|
||||
|
||||
def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
|
||||
return self._stream_states.get(request_id, None)
|
||||
|
||||
def get_stop_tokens(self) -> list[int]:
|
||||
"""
|
||||
Return the list of stop token IDs for Harmony format.
|
||||
@@ -1214,6 +1257,42 @@ class HarmonyAdapter:
|
||||
# Return empty deltas to continue processing
|
||||
return []
|
||||
|
||||
def stateful_stream_harmony_tokens_to_openai_messages(
|
||||
self,
|
||||
request_id: str,
|
||||
tokens: list[int],
|
||||
available_tools: list[dict[str, Any]] | None = None,
|
||||
tool_choice: str | None = None) -> list[Message]:
|
||||
"""
|
||||
Process tokens using stateful parsing.
|
||||
|
||||
This method maintains persistent state across multiple calls for the same request,
|
||||
ensuring proper channel transitions and tool call handling.
|
||||
|
||||
Args:
|
||||
request_id: Request ID to maintain state per request
|
||||
tokens: New tokens from this iteration
|
||||
available_tools: Available tools for filtering
|
||||
|
||||
Returns:
|
||||
List of OpenAI Messages
|
||||
"""
|
||||
stream_state = self._stream_states.get(request_id, None)
|
||||
if stream_state is None:
|
||||
stream_state = self.create_stream_state(request_id, available_tools,
|
||||
tool_choice)
|
||||
|
||||
try:
|
||||
messages = stream_state.process_token_batch_to_messages(tokens)
|
||||
return messages
|
||||
except (HarmonyError, UnicodeDecodeError, ValueError):
|
||||
logger.error(
|
||||
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
|
||||
)
|
||||
logger.debug(f"Problematic streaming tokens: {tokens}")
|
||||
|
||||
return []
|
||||
|
||||
def create_openai_streaming_response(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1406,36 +1485,72 @@ class HarmonyAdapter:
|
||||
return True
|
||||
|
||||
|
||||
async def handle_streaming_response(
|
||||
harmony_adapter: HarmonyAdapter,
|
||||
generator: RequestOutput,
|
||||
request_id: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Handle streaming response with harmony format."""
|
||||
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
|
||||
|
||||
|
||||
def get_harmony_adapter():
|
||||
global _SERVE_HARMONY_ADAPTER
|
||||
if _SERVE_HARMONY_ADAPTER is None:
|
||||
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()
|
||||
|
||||
return _SERVE_HARMONY_ADAPTER
|
||||
|
||||
|
||||
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
|
||||
tool_choice: str, outputs: List, model: str,
|
||||
request_id: str, done: bool,
|
||||
num_prompt_tokens: int):
|
||||
first_iteration = True
|
||||
async for res in generator:
|
||||
output = res.outputs[0]
|
||||
output = outputs[0]
|
||||
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
tools_dict = [tool.model_dump() for tool in request.tools]
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
harmony_adapter = get_harmony_adapter()
|
||||
if tools:
|
||||
tools_dict = [tool.model_dump() for tool in tools]
|
||||
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
tool_choice = getattr(request, 'tool_choice', None)
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
# Create OpenAI streaming responses
|
||||
try:
|
||||
res = []
|
||||
if done:
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
|
||||
usage_info = _create_usage_info(num_prompt_tokens, outputs)
|
||||
|
||||
# Send final message with finish_reason
|
||||
final_response = ChatCompletionStreamResponse(
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
],
|
||||
)
|
||||
|
||||
final_response_json = final_response.model_dump_json(
|
||||
exclude_none=True)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=model,
|
||||
usage=usage_info)
|
||||
final_usage_json = final_usage_chunk.model_dump_json(
|
||||
exclude_none=True)
|
||||
res.append(f"data: {final_response_json}\n\n")
|
||||
res.append(f"data: {final_usage_json}\n\n")
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
# Create OpenAI streaming responses
|
||||
try:
|
||||
responses = harmony_adapter.create_openai_streaming_response(
|
||||
request_id=request_id,
|
||||
tokens=output.token_ids_diff,
|
||||
available_tools=tools_for_parser,
|
||||
model_name=request.model,
|
||||
model_name=model,
|
||||
tool_choice=tool_choice)
|
||||
# Send first response after receiving the first output
|
||||
if first_iteration:
|
||||
@@ -1446,64 +1561,44 @@ async def handle_streaming_response(
|
||||
delta=first_delta)
|
||||
|
||||
first_response = ChatCompletionStreamResponse(
|
||||
model=request.model,
|
||||
model=model,
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
response_json = first_response.model_dump_json(
|
||||
exclude_none=True)
|
||||
yield f"data: {response_json}\n\n"
|
||||
res.append(f"data: {response_json}\n\n")
|
||||
|
||||
for response in responses:
|
||||
yield response
|
||||
res.extend(responses)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OpenAI streaming response: {e}")
|
||||
logger.debug(f"Streaming error details: {traceback.format_exc()}")
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
raise e
|
||||
return res
|
||||
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
|
||||
# Send final message with finish_reason
|
||||
output = generator.outputs[0]
|
||||
final_response = ChatCompletionStreamResponse(
|
||||
model=request.model,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason)
|
||||
])
|
||||
|
||||
yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OpenAI streaming response: {e}")
|
||||
logger.debug(f"Streaming error details: {traceback.format_exc()}")
|
||||
# Clean up state
|
||||
harmony_adapter.cleanup_stream_state(request_id)
|
||||
raise e
|
||||
|
||||
|
||||
async def handle_non_streaming_response(
|
||||
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
|
||||
request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
|
||||
tool_choice: str, outputs: List, model: str,
|
||||
num_prompt_tokens: int):
|
||||
"""Handle non-streaming response with harmony format."""
|
||||
# Get final result
|
||||
await promise
|
||||
|
||||
# Parse harmony output to OpenAI format
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
tools_dict = [tool.model_dump() for tool in request.tools]
|
||||
harmony_adapter = get_harmony_adapter()
|
||||
if tools:
|
||||
tools_dict = [tool.model_dump() for tool in tools]
|
||||
|
||||
# Get tool_choice from request - if "none", don't pass tools to parser
|
||||
tool_choice = getattr(request, 'tool_choice', None)
|
||||
if tool_choice == "none":
|
||||
tools_for_parser = None
|
||||
else:
|
||||
tools_for_parser = tools_dict
|
||||
|
||||
output = promise.outputs[0]
|
||||
output = outputs[0]
|
||||
parsed_output = harmony_adapter.harmony_output_to_openai(
|
||||
output.token_ids, tools_for_parser, tool_choice)
|
||||
|
||||
@@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
|
||||
output.finish_reason)
|
||||
|
||||
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
|
||||
usage_info = _create_usage_info(promise)
|
||||
usage_info = _create_usage_info(num_prompt_tokens, outputs)
|
||||
|
||||
# Create response
|
||||
response = ChatCompletionResponse(
|
||||
model=request.model,
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
@@ -1534,7 +1629,6 @@ async def handle_non_streaming_response(
|
||||
# Optional: Log if harmony parsing failed (for debugging)
|
||||
if parsed_output.get('_harmony_parsing_failed'):
|
||||
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
|
||||
logger.debug(f"request\n\n{request}")
|
||||
logger.debug(f"response\n\n{response}\n")
|
||||
|
||||
return response
|
||||
@@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
|
||||
return reason
|
||||
|
||||
|
||||
def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
|
||||
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
|
||||
"""Create usage info from RequestOutput following serving_chat.py pattern."""
|
||||
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
|
||||
# Calculate completion tokens from all outputs
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in outputs)
|
||||
|
||||
# Create usage info
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
|
||||
@@ -11,9 +11,16 @@ from openai.types.chat import \
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
|
||||
from openai.types.chat import \
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
|
||||
from openai.types.responses import (ResponseFunctionToolCall,
|
||||
ResponseInputItemParam, ResponseOutputItem,
|
||||
ResponsePrompt, ResponseReasoningItem,
|
||||
ResponseStatus, ResponseTextConfig)
|
||||
from openai.types.responses.response import ToolChoice
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai.types.shared import Metadata, Reasoning
|
||||
from openai_harmony import ReasoningEffort
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
from typing_extensions import Annotated, Required, TypeAlias, TypedDict
|
||||
|
||||
from tensorrt_llm.executor.request import LoRARequest
|
||||
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
|
||||
@@ -665,6 +672,208 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
return data
|
||||
|
||||
|
||||
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
|
||||
ResponseReasoningItem,
|
||||
ResponseFunctionToolCall]
|
||||
|
||||
|
||||
class ResponsesRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/responses/create
|
||||
background: Optional[bool] = False
|
||||
include: Optional[list[
|
||||
Literal[
|
||||
"code_interpreter_call.outputs",
|
||||
"computer_call_output.output.image_url",
|
||||
"file_search_call.results",
|
||||
"message.input_image.image_url",
|
||||
"message.output_text.logprobs",
|
||||
"reasoning.encrypted_content",
|
||||
],
|
||||
]] = None
|
||||
input: Union[str, list[ResponseInputOutputItem]]
|
||||
instructions: Optional[str] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
max_tool_calls: Optional[int] = None
|
||||
metadata: Optional[Metadata] = None
|
||||
model: str
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
previous_response_id: Optional[str] = None
|
||||
prompt: Optional[ResponsePrompt] = None
|
||||
reasoning: Optional[Reasoning] = None
|
||||
service_tier: Literal["auto", "default", "flex", "scale",
|
||||
"priority"] = "auto"
|
||||
store: Optional[bool] = True
|
||||
stream: Optional[bool] = False
|
||||
temperature: Optional[float] = None
|
||||
text: Optional[ResponseTextConfig] = None
|
||||
tool_choice: ToolChoice = "auto"
|
||||
tools: list[Tool] = Field(default_factory=list)
|
||||
top_logprobs: Optional[int] = 0
|
||||
top_p: Optional[float] = None
|
||||
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
|
||||
user: Optional[str] = None
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."),
|
||||
)
|
||||
|
||||
_DEFAULT_SAMPLING_PARAMS = {
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
}
|
||||
|
||||
def to_sampling_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: Optional[dict] = None,
|
||||
) -> SamplingParams:
|
||||
if self.max_output_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
else:
|
||||
max_tokens = min(self.max_output_tokens, default_max_tokens)
|
||||
|
||||
default_sampling_params = default_sampling_params or {}
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||
if (top_p := self.top_p) is None:
|
||||
top_p = default_sampling_params.get(
|
||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
||||
|
||||
# Structured output
|
||||
guided_decoding = None
|
||||
if self.text is not None and self.text.format is not None:
|
||||
response_format = self.text.format
|
||||
if response_format.type == "json_schema":
|
||||
guided_decoding = GuidedDecodingParams(
|
||||
json=response_format.schema_)
|
||||
elif response_format.type == "json_object":
|
||||
raise NotImplementedError("json_object is not supported")
|
||||
|
||||
return SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=self.top_logprobs,
|
||||
stop_token_ids=stop_token_ids,
|
||||
guided_decoding=guided_decoding,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_background(cls, data):
|
||||
if not data.get("background"):
|
||||
return data
|
||||
if not data.get("store", True):
|
||||
raise ValueError("background can only be used when `store` is true")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_prompt(cls, data):
|
||||
if data.get("prompt") is not None:
|
||||
raise ValueError("prompt template is not supported")
|
||||
return data
|
||||
|
||||
|
||||
class InputTokensDetails(OpenAIBaseModel):
|
||||
cached_tokens: int
|
||||
|
||||
|
||||
class OutputTokensDetails(OpenAIBaseModel):
|
||||
reasoning_tokens: int
|
||||
|
||||
|
||||
class ResponseUsage(OpenAIBaseModel):
|
||||
input_tokens: int
|
||||
input_tokens_details: InputTokensDetails
|
||||
output_tokens: int
|
||||
output_tokens_details: OutputTokensDetails
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class ResponsesResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}")
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
# error: Optional[ResponseError] = None
|
||||
# incomplete_details: Optional[IncompleteDetails] = None
|
||||
instructions: Optional[str] = None
|
||||
metadata: Optional[Metadata] = None
|
||||
model: str
|
||||
object: Literal["response"] = "response"
|
||||
output: list[ResponseOutputItem]
|
||||
parallel_tool_calls: bool
|
||||
temperature: float
|
||||
tool_choice: ToolChoice
|
||||
tools: list[Tool]
|
||||
top_p: float
|
||||
background: bool
|
||||
max_output_tokens: int
|
||||
max_tool_calls: Optional[int] = None
|
||||
previous_response_id: Optional[str] = None
|
||||
prompt: Optional[ResponsePrompt] = None
|
||||
reasoning: Optional[Reasoning] = None
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
|
||||
status: ResponseStatus
|
||||
text: Optional[ResponseTextConfig] = None
|
||||
top_logprobs: int
|
||||
truncation: Literal["auto", "disabled"]
|
||||
usage: Optional[ResponseUsage] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: ResponsesRequest,
|
||||
sampling_params: SamplingParams,
|
||||
model_name: str,
|
||||
created_time: int,
|
||||
output: list[ResponseOutputItem],
|
||||
status: ResponseStatus,
|
||||
usage: Optional[ResponseUsage] = None,
|
||||
) -> "ResponsesResponse":
|
||||
return cls(
|
||||
id=request.request_id,
|
||||
created_at=created_time,
|
||||
instructions=request.instructions,
|
||||
metadata=request.metadata,
|
||||
model=model_name,
|
||||
output=output,
|
||||
parallel_tool_calls=request.parallel_tool_calls,
|
||||
temperature=sampling_params.temperature,
|
||||
tool_choice=request.tool_choice,
|
||||
tools=request.tools,
|
||||
top_p=sampling_params.top_p,
|
||||
background=request.background,
|
||||
max_output_tokens=sampling_params.max_tokens,
|
||||
max_tool_calls=request.max_tool_calls,
|
||||
previous_response_id=request.previous_response_id,
|
||||
prompt=request.prompt,
|
||||
reasoning=request.reasoning,
|
||||
service_tier=request.service_tier,
|
||||
status=status,
|
||||
text=request.text,
|
||||
top_logprobs=sampling_params.logprobs,
|
||||
truncation=request.truncation,
|
||||
user=request.user,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class ResponsesStreamResponse(OpenAIBaseModel):
|
||||
response: ResponsesResponse
|
||||
sequence_number: int
|
||||
type: Literal["response.created", "response.in_progress",
|
||||
"response.completed", "response.failed",
|
||||
"response.incomplete"]
|
||||
|
||||
|
||||
def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
|
||||
if opaque_state is None:
|
||||
return None
|
||||
|
||||
@@ -41,17 +41,25 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
ErrorResponse, ModelCard,
|
||||
ModelList, UsageInfo,
|
||||
ModelList, ResponsesRequest,
|
||||
UsageInfo,
|
||||
to_llm_disaggregated_params)
|
||||
from tensorrt_llm.serve.postprocess_handlers import (
|
||||
ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
|
||||
chat_stream_post_processor, completion_response_post_processor,
|
||||
completion_stream_post_processor)
|
||||
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
|
||||
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
|
||||
chat_response_post_processor, chat_stream_post_processor,
|
||||
completion_response_post_processor, completion_stream_post_processor)
|
||||
from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
create_response as responses_api_create_response
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
process_streaming_events as responses_api_process_streaming_events
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
request_preprocess as responses_api_request_preprocess
|
||||
from tensorrt_llm.version import __version__ as VERSION
|
||||
|
||||
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
|
||||
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
|
||||
handle_streaming_response,
|
||||
from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
|
||||
maybe_transform_reasoning_effort)
|
||||
|
||||
# yapf: enale
|
||||
@@ -83,6 +91,12 @@ class OpenAIServer:
|
||||
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
|
||||
self.model_config = None
|
||||
|
||||
# Enable response storage for Responses API
|
||||
self.enable_store = True
|
||||
if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
|
||||
self.enable_store = False
|
||||
self.conversation_store = ConversationHistoryStore()
|
||||
|
||||
model_dir = Path(model)
|
||||
if model_dir.exists() and model_dir.is_dir():
|
||||
self.model = model_dir.name
|
||||
@@ -104,7 +118,11 @@ class OpenAIServer:
|
||||
|
||||
# gpt-oss
|
||||
self.harmony_adapter: HarmonyAdapter | None = None
|
||||
self.use_harmony = self.model_config.model_type == "gpt_oss"
|
||||
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
|
||||
if disable_harmony:
|
||||
self.use_harmony = False
|
||||
else:
|
||||
self.use_harmony = (self.model_config.model_type == "gpt_oss")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -166,6 +184,20 @@ class OpenAIServer:
|
||||
return JSONResponse(content=error_response.model_dump(),
|
||||
status_code=error_response.code)
|
||||
|
||||
def _create_invalid_response_id_error(self, response_id: str) -> Response:
|
||||
return self.create_error_response(
|
||||
err_type="InvalidRequestError",
|
||||
message=(f"Invalid 'response_id': '{response_id}'. "
|
||||
"Expected an ID that begins with 'resp'."),
|
||||
)
|
||||
|
||||
def _create_response_id_not_found_error(self, response_id: str) -> Response:
|
||||
return self.create_error_response(
|
||||
err_type="InvalidRequestError",
|
||||
message=f"Response with id '{response_id}' not found.",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
|
||||
def register_routes(self):
|
||||
self.app.add_api_route("/health", self.health, methods=["GET"])
|
||||
self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
|
||||
@@ -182,6 +214,9 @@ class OpenAIServer:
|
||||
self.app.add_api_route("/v1/chat/completions",
|
||||
self.openai_chat if not self.use_harmony else self.chat_harmony,
|
||||
methods=["POST"])
|
||||
self.app.add_api_route("/v1/responses",
|
||||
self.openai_responses,
|
||||
methods=["POST"])
|
||||
if self.llm.args.return_perf_metrics:
|
||||
# register /prometheus/metrics
|
||||
self.mount_metrics()
|
||||
@@ -681,11 +716,35 @@ class OpenAIServer:
|
||||
Chat Completion API with harmony format support.
|
||||
Supports both streaming and non-streaming modes.
|
||||
"""
|
||||
|
||||
async def create_harmony_response(
|
||||
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
|
||||
await promise.aresult()
|
||||
if self.postproc_worker_enabled:
|
||||
chat_response = promise.outputs[0]._postprocess_result
|
||||
else:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
chat_response = post_processor(promise, args)
|
||||
|
||||
return chat_response
|
||||
|
||||
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
|
||||
if not self.postproc_worker_enabled:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
|
||||
async for res in promise:
|
||||
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
|
||||
# await self._extract_metrics(res)
|
||||
for pp_res in pp_results:
|
||||
yield pp_res
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
try:
|
||||
# Initialize HarmonyAdapter
|
||||
# NOTE: WAR for Disagg failure, may affect perf if no warmup
|
||||
if not self.harmony_adapter:
|
||||
self.harmony_adapter = HarmonyAdapter()
|
||||
self.harmony_adapter = get_harmony_adapter()
|
||||
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
|
||||
tools_dict = None
|
||||
if request.tools:
|
||||
@@ -720,27 +779,37 @@ class OpenAIServer:
|
||||
vocab_size=self.tokenizer.tokenizer.vocab_size)
|
||||
sampling_params.detokenize = False # Harmony adapter handles detokenization
|
||||
|
||||
postproc_args = ChatCompletionPostprocArgs.from_request(request)
|
||||
postproc_params = PostprocParams(
|
||||
post_processor=chat_harmony_streaming_post_processor
|
||||
if request.stream else chat_harmony_post_processor,
|
||||
postproc_args=postproc_args,
|
||||
)
|
||||
|
||||
# Generate
|
||||
promise = self.llm.generate_async(
|
||||
inputs=harmony_tokens,
|
||||
sampling_params=sampling_params,
|
||||
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
|
||||
streaming=bool(request.stream),
|
||||
lora_request=request.lora_request,
|
||||
)
|
||||
postproc_args.request_id = promise.request_id
|
||||
|
||||
if not self.postproc_worker_enabled:
|
||||
postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
|
||||
|
||||
# Disconnect cancellation
|
||||
asyncio.create_task(self.await_disconnected(raw_request, promise))
|
||||
|
||||
# Handle streaming
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
handle_streaming_response(
|
||||
self.harmony_adapter, promise,
|
||||
str(promise.request_id), request,
|
||||
),
|
||||
content=create_streaming_generator(promise, postproc_params),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
|
||||
response = await create_harmony_response(promise, postproc_params)
|
||||
return JSONResponse(response.model_dump())
|
||||
|
||||
except Exception as e:
|
||||
@@ -748,6 +817,80 @@ class OpenAIServer:
|
||||
logger.debug("Error details: %s", traceback.format_exc())
|
||||
return self.create_error_response(message=str(e), err_type="internal_error")
|
||||
|
||||
async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
|
||||
async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
|
||||
async for event_data in responses_api_process_streaming_events(
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
generator=generator,
|
||||
harmony_adapter=self.harmony_adapter,
|
||||
model_name=self.model,
|
||||
conversation_store=self.conversation_store,
|
||||
enable_store=self.enable_store
|
||||
):
|
||||
yield event_data
|
||||
|
||||
try:
|
||||
if not self.use_harmony:
|
||||
raise NotImplementedError("Responses API only supports harmony format for now")
|
||||
|
||||
# Initialize HarmonyAdapter
|
||||
# NOTE: WAR for Disagg failure, may affect perf if no warmup
|
||||
if not self.harmony_adapter:
|
||||
self.harmony_adapter = HarmonyAdapter()
|
||||
|
||||
if request.background:
|
||||
logger.warning("Request.background is not supported yet, will fallback to foreground processing.")
|
||||
|
||||
# Get prev response
|
||||
prev_response = None
|
||||
if self.enable_store:
|
||||
prev_response_id = request.previous_response_id
|
||||
if prev_response_id is not None:
|
||||
if not prev_response_id.startswith("resp_"):
|
||||
return self._create_invalid_response_id_error(prev_response_id)
|
||||
|
||||
prev_response = await self.conversation_store.load_response(prev_response_id)
|
||||
if prev_response is None:
|
||||
logger.debug(f"response_id {prev_response_id} not found")
|
||||
return self._create_response_id_not_found_error(prev_response_id)
|
||||
|
||||
input_tokens, sampling_params = await responses_api_request_preprocess(
|
||||
request, prev_response, self.harmony_adapter, self.conversation_store, self.enable_store)
|
||||
|
||||
promise = self.llm.generate_async(
|
||||
inputs=input_tokens,
|
||||
sampling_params=sampling_params,
|
||||
streaming=request.stream,
|
||||
)
|
||||
|
||||
asyncio.create_task(self.await_disconnected(raw_request, promise))
|
||||
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
create_stream_response(promise, request, sampling_params),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
return await responses_api_create_response(
|
||||
generator=promise,
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
model_name=self.model,
|
||||
conversation_store=self.conversation_store,
|
||||
generation_result=None,
|
||||
enable_store=self.enable_store)
|
||||
except CppExecutorError:
|
||||
logger.error(traceback.format_exc())
|
||||
# If internal executor error is raised, shutdown the server
|
||||
signal.raise_signal(signal.SIGINT)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return JSONResponse(content={"detail": "None"})
|
||||
|
||||
|
||||
async def __call__(self, host, port):
|
||||
# Store the binding address for server registration
|
||||
self.binding_addr = f"http://{host}:{port}"
|
||||
|
||||
@@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
|
||||
ReasoningParserFactory)
|
||||
from ..llmapi.tokenizer import TransformersTokenizer
|
||||
# yapf: disable
|
||||
from .harmony_adapter import (handle_non_streaming_response,
|
||||
handle_streaming_response)
|
||||
from .openai_protocol import (ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent,
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
@@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
|
||||
FunctionCall, StreamOptions, ToolCall, UsageInfo,
|
||||
to_disaggregated_params)
|
||||
|
||||
# yapf: enale
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ChatPostprocArgs(PostprocArgs):
|
||||
@@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs):
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(token_ids: List[int],
|
||||
tokenizer: TransformersTokenizer,
|
||||
def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
|
||||
logprobs: List[float]) -> ChatCompletionLogProbs:
|
||||
assert len(token_ids) == len(logprobs), \
|
||||
"token_ids and logprobs have different lengths"
|
||||
@@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
|
||||
return chat_logprobs
|
||||
|
||||
|
||||
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
|
||||
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
|
||||
streaming: bool) -> Tuple[bool, str, str]:
|
||||
reasoning_parser = None
|
||||
if args.reasoning_parser is not None:
|
||||
if output_index not in args.reasoning_parser_dict:
|
||||
args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
|
||||
args.reasoning_parser)
|
||||
args.reasoning_parser_dict[
|
||||
output_index] = ReasoningParserFactory.create_reasoning_parser(
|
||||
args.reasoning_parser)
|
||||
reasoning_parser = args.reasoning_parser_dict[output_index]
|
||||
|
||||
in_reasoning = False
|
||||
@@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_stream_post_processor")
|
||||
def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
|
||||
def chat_stream_post_processor(rsp: GenerationResultBase,
|
||||
args: ChatPostprocArgs) -> List[str]:
|
||||
|
||||
def yield_first_chat(num_tokens: int,
|
||||
idx: int,
|
||||
@@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
include_continuous_usage = False
|
||||
if args.first_iteration:
|
||||
for i in range(args.num_choices):
|
||||
res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
|
||||
res.append(
|
||||
f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
|
||||
)
|
||||
if args.echo and args.last_message_content:
|
||||
res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
|
||||
res.append(
|
||||
f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
|
||||
)
|
||||
args.first_iteration = False
|
||||
|
||||
for output in rsp.outputs:
|
||||
@@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text, reasoning_content=reasoning_delta_text)
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(index=i,
|
||||
delta=delta_message,
|
||||
finish_reason=None,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
finish_reason=None,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None))
|
||||
if args.return_logprobs:
|
||||
logprobs = output.logprobs_diff
|
||||
token_ids = output.token_ids_diff
|
||||
choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
|
||||
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
|
||||
logprobs)
|
||||
if output.finish_reason is not None:
|
||||
choice.finish_reason = output.finish_reason
|
||||
choice.stop_reason = output.stop_reason
|
||||
@@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
|
||||
res.append(f"data: {data}\n\n")
|
||||
|
||||
if include_usage and rsp._done:
|
||||
completion_tokens = sum(output.length
|
||||
for output in rsp.outputs)
|
||||
completion_tokens = sum(output.length for output in rsp.outputs)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
choices=[], model=args.model, usage=final_usage)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=args.model,
|
||||
usage=final_usage)
|
||||
final_usage_data = final_usage_chunk.model_dump_json()
|
||||
res.append(f"data: {final_usage_data}\n\n")
|
||||
return res
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_response_post_processor")
|
||||
def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
|
||||
def chat_response_post_processor(
|
||||
rsp: GenerationResultBase,
|
||||
args: ChatPostprocArgs) -> ChatCompletionResponse:
|
||||
choices: List[ChatCompletionResponseChoice] = []
|
||||
role = args.role
|
||||
for output in rsp.outputs:
|
||||
_, text, reasoning_text = apply_reasoning_parser(
|
||||
args, output.index, output.text, False)
|
||||
|
||||
if args.tool_choice and isinstance(
|
||||
args.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
if args.tool_choice and isinstance(args.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
name=args.tool_choice.function.name,
|
||||
arguments=text))
|
||||
name=args.tool_choice.function.name, arguments=text))
|
||||
])
|
||||
else:
|
||||
if text is None:
|
||||
text = ""
|
||||
message = ChatMessage(
|
||||
role=role, content=text, reasoning_content=reasoning_text)
|
||||
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
|
||||
message = ChatMessage(role=role,
|
||||
content=text,
|
||||
reasoning_content=reasoning_text)
|
||||
disaggregated_params = to_disaggregated_params(
|
||||
output.disaggregated_params)
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
disaggregated_params=disaggregated_params,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
|
||||
if args.return_logprobs:
|
||||
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
|
||||
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
|
||||
output.logprobs)
|
||||
choices.append(choice)
|
||||
|
||||
if args.echo and args.last_message_content:
|
||||
@@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = args.num_prompt_tokens
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in rsp.outputs)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
@@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs):
|
||||
|
||||
|
||||
@nvtx_range_debug("completion_stream_post_processor")
|
||||
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
|
||||
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
|
||||
args: CompletionPostprocArgs) -> List[str]:
|
||||
res: List[str] = []
|
||||
prompt_tokens = args.num_prompt_tokens
|
||||
if stream_option := args.stream_options:
|
||||
@@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
index=args.prompt_idx * args.num_choices + output.index,
|
||||
text=delta_text if args.detokenize else "",
|
||||
token_ids=None if args.detokenize else output.token_ids_diff,
|
||||
finish_reason = output.finish_reason,
|
||||
stop_reason = output.stop_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
|
||||
if include_continuous_usage:
|
||||
@@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
res.append(f"data: {data}\n\n")
|
||||
|
||||
if include_usage and rsp._done:
|
||||
completion_tokens = sum(output.length
|
||||
for output in rsp.outputs)
|
||||
completion_tokens = sum(output.length for output in rsp.outputs)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
choices=[], model=args.model, usage=final_usage)
|
||||
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
|
||||
model=args.model,
|
||||
usage=final_usage)
|
||||
final_usage_data = final_usage_chunk.model_dump_json()
|
||||
res.append(f"data: {final_usage_data}\n\n")
|
||||
args.first_iteration = False
|
||||
@@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
|
||||
|
||||
|
||||
@nvtx_range_debug("completion_response_post_processor")
|
||||
def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
|
||||
def completion_response_post_processor(
|
||||
rsp: GenerationResult,
|
||||
args: CompletionPostprocArgs) -> CompletionResponse:
|
||||
prompt_tokens = args.num_prompt_tokens
|
||||
completion_tokens = 0
|
||||
choices = []
|
||||
@@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
|
||||
text = output.text
|
||||
if args.echo:
|
||||
text = args.prompt + text
|
||||
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
|
||||
disaggregated_params = to_disaggregated_params(
|
||||
output.disaggregated_params)
|
||||
choice = CompletionResponseChoice(
|
||||
text=text if args.detokenize else "",
|
||||
token_ids=None if args.detokenize else output.token_ids,
|
||||
index=args.prompt_idx * args.num_choices + output.index,
|
||||
disaggregated_params=disaggregated_params,
|
||||
context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
|
||||
context_logits=None
|
||||
if rsp.context_logits is None else rsp.context_logits.tolist(),
|
||||
stop_reason=output.stop_reason,
|
||||
finish_reason=output.finish_reason,
|
||||
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
|
||||
avg_decoded_tokens_per_iter=getattr(rsp,
|
||||
'avg_decoded_tokens_per_iter',
|
||||
None),
|
||||
)
|
||||
|
||||
completion_tokens += output.length
|
||||
choices.append(choice)
|
||||
|
||||
usage = UsageInfo(prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=completion_tokens + prompt_tokens)
|
||||
response = CompletionResponse(choices=choices, model=args.model, usage=usage)
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=completion_tokens + prompt_tokens)
|
||||
response = CompletionResponse(choices=choices,
|
||||
model=args.model,
|
||||
usage=usage)
|
||||
return response
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ChatCompletionPostprocArgs(PostprocArgs):
|
||||
model: str
|
||||
tools: Optional[List[ChatCompletionToolsParam]]
|
||||
tool_choice: Optional[Union[Literal["none", "auto"],
|
||||
ChatCompletionNamedToolChoiceParam]]
|
||||
request_id: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request: ChatCompletionRequest):
|
||||
return cls(
|
||||
model=request.model,
|
||||
tools=request.tools,
|
||||
tool_choice=request.tool_choice,
|
||||
)
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_harmony_post_processor")
|
||||
def chat_harmony_post_processor(
|
||||
rsp: GenerationResult,
|
||||
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
|
||||
response = handle_non_streaming_response(
|
||||
tools=args.tools,
|
||||
tool_choice=args.tool_choice,
|
||||
outputs=rsp.outputs,
|
||||
model=args.model,
|
||||
num_prompt_tokens=args.num_prompt_tokens,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@nvtx_range_debug("chat_harmony_streaming_post_processor")
|
||||
def chat_harmony_streaming_post_processor(
|
||||
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
|
||||
response = handle_streaming_response(
|
||||
tools=args.tools,
|
||||
tool_choice=args.tool_choice,
|
||||
outputs=rsp.outputs,
|
||||
model=args.model,
|
||||
request_id=args.request_id,
|
||||
done=rsp._done,
|
||||
num_prompt_tokens=args.num_prompt_tokens,
|
||||
)
|
||||
return response
|
||||
|
||||
tensorrt_llm/serve/responses_utils.py (new file, 848 lines)
@@ -0,0 +1,848 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from copy import copy
|
||||
from typing import Literal, Optional, OrderedDict, Union
|
||||
|
||||
# yapf: disable
|
||||
from openai.types.responses import (ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInProgressEvent, ResponseOutputItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage, ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent)
|
||||
# yapf: enable
|
||||
from openai.types.responses.response_function_web_search import (
|
||||
ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
|
||||
from openai.types.responses.response_reasoning_item import Content
|
||||
from openai.types.responses.tool import Tool
|
||||
from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, Message, ReasoningEffort, Role,
|
||||
StreamState, SystemContent, TextContent,
|
||||
ToolDescription, load_harmony_encoding)
|
||||
|
||||
from tensorrt_llm.llmapi import SamplingParams
|
||||
from tensorrt_llm.llmapi.llm import RequestOutput
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel,
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse)
|
||||
|
||||
from .harmony_adapter import HarmonyAdapter
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
"low": ReasoningEffort.LOW,
|
||||
}
|
||||
|
||||
ENABLE_RESPONSES_DEBUG_MSG = False
|
||||
|
||||
|
||||
def responses_debug_log(msg):
|
||||
if ENABLE_RESPONSES_DEBUG_MSG:
|
||||
logger.debug(msg)
|
||||
|
||||
|
||||
_harmony_encoding = None
|
||||
|
||||
|
||||
def random_uuid():
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def get_encoding():
|
||||
global _harmony_encoding
|
||||
if _harmony_encoding is None:
|
||||
_harmony_encoding = load_harmony_encoding(
|
||||
HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return _harmony_encoding
|
||||
|
||||
|
||||
def decode_tokens(tokens):
|
||||
return get_encoding().decode(tokens)
|
||||
|
||||
|
||||
def parse_response_input(
|
||||
input_msg: ResponseInputOutputItem,
|
||||
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
|
||||
) -> Message:
|
||||
if not isinstance(input_msg, dict):
|
||||
input_msg = input_msg.model_dump()
|
||||
|
||||
responses_debug_log(f"------- Parsing input -----------")
|
||||
responses_debug_log(input_msg)
|
||||
responses_debug_log("")
|
||||
|
||||
if "type" not in input_msg or input_msg["type"] == "message":
|
||||
role = input_msg["role"]
|
||||
content = input_msg["content"]
|
||||
if role == "system":
|
||||
# User is trying to set a system message. Change it to:
|
||||
# <|start|>developer<|message|># Instructions
|
||||
# {instructions}<|end|>
|
||||
role = "developer"
|
||||
text_prefix = "Instructions:\n"
|
||||
else:
|
||||
text_prefix = ""
|
||||
if isinstance(content, str):
|
||||
msg = Message.from_role_and_content(role, text_prefix + content)
|
||||
elif isinstance(content, list):
|
||||
contents = [
|
||||
TextContent(text=text_prefix + c["text"]) for c in content
|
||||
]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
else:
|
||||
logger.warning("Responses API: Invalid input message type")
|
||||
msg = None
|
||||
elif input_msg["type"] == "function_call_output":
|
||||
call_id = input_msg["call_id"]
|
||||
call_response: Optional[ResponseFunctionToolCall] = None
|
||||
for prev_response in reversed(prev_responses):
|
||||
if isinstance(prev_response, ResponseFunctionToolCall
|
||||
) and prev_response.call_id == call_id:
|
||||
call_response = prev_response
|
||||
break
|
||||
if call_response is None:
|
||||
raise ValueError(f"No call message found for {call_id}")
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{call_response.name}"),
|
||||
input_msg["output"])
|
||||
elif input_msg["type"] == "reasoning":
|
||||
content = input_msg["content"]
|
||||
assert len(content) == 1
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
|
||||
elif input_msg["type"] == "function_call":
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT,
|
||||
input_msg["arguments"])
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{input_msg['name']}")
|
||||
msg = msg.with_content_type("json")
|
||||
else:
|
||||
raise ValueError(f"Unknown input type: {input_msg['type']}")
|
||||
return msg
|
||||
|
||||
|
||||
class ConversationHistoryStore:
|
||||
|
||||
def __init__(self, resp_capacity: int = 16, max_conversations=32):
|
||||
self.response_capacity = resp_capacity
|
||||
self.conversation_capacity = resp_capacity * 4
|
||||
self.max_conversations = max_conversations
|
||||
|
||||
self.responses_lock = asyncio.Lock()
|
||||
self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict()
|
||||
|
||||
self.conversations_lock = asyncio.Lock()
|
||||
self.conversations: OrderedDict[str, list[Message]] = OrderedDict()
|
||||
self.response_to_conversation: dict[str, str] = {}
|
||||
self.conversation_to_response: dict[str, str] = {}
|
||||
|
||||
async def load_response(self, resp_id: str) -> ResponsesResponse:
|
||||
responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}")
|
||||
async with self.responses_lock:
|
||||
return self.responses.get(resp_id)
|
||||
|
||||
async def store_response(self,
|
||||
resp: ResponsesResponse,
|
||||
resp_msgs: Optional[list[Message]] = [],
|
||||
prev_resp_id: Optional[str] = None) -> None:
|
||||
resp_id = resp.id
|
||||
responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}")
|
||||
async with self.responses_lock:
|
||||
self.responses[resp_id] = resp
|
||||
if len(self.responses) > self.response_capacity:
|
||||
self._pop_response()
|
||||
|
||||
        async with self.conversations_lock:
            conversation_id: str
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
            elif prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
                self.conversations[conversation_id].extend(resp_msgs)
                while len(self.conversations[conversation_id]
                          ) > self.conversation_capacity:
                    self._pop_conversation(resp_id)
            else:
                conversation_id = random_uuid()
                self.conversations[conversation_id] = resp_msgs

            responses_debug_log(
                f" * storing at conversation id: {conversation_id}")

            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            self._update_visited_conversation(conversation_id)

    async def store_messages(self, resp_id: str, msgs: list[Message],
                             prev_resp_id: Optional[str]):
        responses_debug_log(f"ConversationHistoryStore storing msg:")
        for msg in msgs:
            responses_debug_log(f" -> {msg.to_json()}")

        async with self.conversations_lock:
            conversation_id: str
            if prev_resp_id is not None:
                conversation_id = self.response_to_conversation[prev_resp_id]
            else:
                conversation_id = random_uuid()

            responses_debug_log(
                f" * storing at conversation: {conversation_id}")
            self.conversations[conversation_id] = msgs
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)

            self.response_to_conversation[resp_id] = conversation_id
            self.conversation_to_response[conversation_id] = resp_id
            self._update_visited_conversation(conversation_id)

    async def append_messages(self, resp_id: str, msgs: list[Message]):
        responses_debug_log(f"ConversationHistoryStore appending msgs:")
        for msg in msgs:
            responses_debug_log(f" -> {msg.to_json()}")

        async with self.conversations_lock:
            assert resp_id in self.response_to_conversation
            conversation_id = self.response_to_conversation[resp_id]

            responses_debug_log(
                f" * appending at conversation: {conversation_id}")
            self.conversations[conversation_id].extend(msgs)
            if len(self.conversations[conversation_id]
                   ) > self.conversation_capacity:
                self._pop_conversation(resp_id)
            self._update_visited_conversation(conversation_id)

    async def get_conversation_history(self, resp_id: str) -> list[Message]:
        responses_debug_log(f"ConversationHistoryStore getting prev_msgs:")
        responses_debug_log(f" -> prev_resp_id: {resp_id}")
        async with self.conversations_lock:
            if resp_id in self.response_to_conversation:
                conversation_id = self.response_to_conversation[resp_id]
                self._update_visited_conversation(conversation_id)
                return self.conversations.get(conversation_id, [])

        return []

    def _update_visited_conversation(self, conversation_id) -> None:
        if conversation_id not in self.conversations:
            return

        self.conversations.move_to_end(conversation_id)
        if len(self.conversations) > self.max_conversations:
            removed_id, _ = self.conversations.popitem(last=False)
            responses_debug_log(
                f"ConversationHistoryStore Removing conversation {removed_id}")
            removed_resp_id = self.conversation_to_response[removed_id]
            # The responses may have been removed due to response capacity
            if removed_resp_id in self.response_to_conversation:
                self.response_to_conversation.pop(removed_resp_id)
            self.conversation_to_response.pop(removed_id)

    def _pop_conversation(self, resp_id) -> None:
        conversation_id = self.response_to_conversation.get(resp_id, None)
        if conversation_id is None:
            return

        conversation = self.conversations[conversation_id]
        first_conversation_range = []
        for i, msg in enumerate(conversation):
            if msg.author.role == Role.USER:
                first_conversation_range.append(i)
            elif msg.channel == "final":
                first_conversation_range.append(i)
                break
        del conversation[
            first_conversation_range[0]:first_conversation_range[1] + 1]

    def _pop_response(self) -> None:
        responses_debug_log(f"responses type: {type(self.responses)}")
        resp_id, _ = self.responses.popitem(last=False)
        if resp_id in self.response_to_conversation:
            self.response_to_conversation.pop(resp_id)
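

# Editor's note: the eviction in _update_visited_conversation relies on the
# OrderedDict LRU idiom (move_to_end on access, popitem(last=False) on
# overflow). The function below is an illustrative, self-contained sketch of
# that idiom; it is not called anywhere in this module.
def _lru_touch_example(cache, key, value, capacity: int = 4) -> None:
    """Illustrative sketch only: the LRU bookkeeping used by the store."""
    cache[key] = value
    cache.move_to_end(key)  # mark as most recently used
    if len(cache) > capacity:
        cache.popitem(last=False)  # evict the least recently used entry

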
def get_system_message(
    model_identity: Optional[str] = None,
    reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
    start_date: Optional[str] = None,
    browser_description: Optional[str] = None,
    python_description: Optional[str] = None,
) -> Message:
    sys_msg_content = SystemContent.new()
    if model_identity is not None:
        sys_msg_content = sys_msg_content.with_model_identity(model_identity)
    if reasoning_effort is not None:
        sys_msg_content = sys_msg_content.with_reasoning_effort(
            REASONING_EFFORT[reasoning_effort])
    if start_date:
        sys_msg_content = sys_msg_content.with_conversation_start_date(
            start_date)
    if browser_description is not None:
        sys_msg_content = sys_msg_content.with_tools(browser_description)
    if python_description is not None:
        sys_msg_content = sys_msg_content.with_tools(python_description)
    sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
    return sys_msg


def get_developer_message(instructions: Optional[str] = None,
                          tools: Optional[list[Tool]] = None) -> Message:
    dev_msg_content = DeveloperContent.new()
    if instructions is not None:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools = []
        for tool in tools:
            if tool.type in ("web_search_preview", "code_interpreter"):
                # These are built-in tools that are added to the system message.
                pass
            elif tool.type == "function":
                function_tools.append(tool)
            else:
                raise ValueError(f"tool type {tool.type} not supported")
        if function_tools:
            function_tool_descriptions = [
                ToolDescription.new(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.parameters,
                ) for tool in function_tools
            ]
            dev_msg_content = dev_msg_content.with_function_tools(
                function_tool_descriptions)
    dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
    return dev_msg


def get_user_message(content: str) -> Message:
    return Message.from_role_and_content(Role.USER, content)


def construct_harmony_messages(
    request: ResponsesRequest,
    prev_response: Optional[ResponsesResponse],
    prev_msgs: list[Message] = [],
) -> list[Message]:
    """Construct messages from the request input, including conversation history messages if they exist."""
    messages: list[Message] = []
    if prev_response is None:
        # New conversation.
        reasoning_effort = (request.reasoning.effort
                            if request.reasoning else None)
        sys_msg = get_system_message(reasoning_effort=reasoning_effort, )
        messages.append(sys_msg)
        dev_msg = get_developer_message(request.instructions, request.tools)
        messages.append(dev_msg)
    else:
        messages.extend(prev_msgs)
    # Append the new input.
    # Responses API supports simple text inputs without chat format.
    if isinstance(request.input, str):
        messages.append(get_user_message(request.input))
    else:
        if prev_response is not None:
            prev_outputs = copy(prev_response.output)
        else:
            prev_outputs = []
        for input_msg in request.input:
            msg = parse_response_input(input_msg, prev_outputs)
            if msg is not None:
                messages.append(msg)
            # The user passes in a tool call request and its output. We need
            # to add the tool call request to prev_outputs so that
            # parse_response_input can find the tool call request when
            # parsing the tool call output.
            if isinstance(input_msg, ResponseFunctionToolCall):
                prev_outputs.append(input_msg)
    return messages


def render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    responses_debug_log("Rendering conversation:")
    responses_debug_log(conversation.to_json())
    token_ids = get_encoding().render_conversation_for_completion(
        conversation, Role.ASSISTANT)
    return token_ids
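

# Editor's note: an illustrative, self-contained sketch (not called anywhere
# in this module) of how the helpers above combine into a Harmony prefill:
# system and developer messages first, then the user turn, then rendering to
# token ids for the completion request.
def _render_prefill_example(user_text: str) -> list[int]:
    messages = [
        get_system_message(reasoning_effort="low"),
        get_developer_message(instructions="Answer concisely.", tools=None),
        get_user_message(user_text),
    ]
    return render_for_completion(messages)

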
def parse_output_tokens(tokens: list[int]) -> list[Message]:
    return get_encoding().parse_messages_from_completion_tokens(
        tokens, role=Role.ASSISTANT)


def parse_output_message(message: Message) -> list[ResponseOutputItem]:
    """
    Parse a Harmony message into a list of output response items.
    """
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    output_items: list[ResponseOutputItem] = []
    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        if len(message.content) != 1:
            raise ValueError("Invalid number of contents in browser message")
        content = message.content[0]
        browser_call = json.loads(content.text)
        # TODO: translate to url properly!
        if recipient == "browser.search":
            action = ActionSearch(
                query=f"cursor:{browser_call.get('query', '')}", type="search")
        elif recipient == "browser.open":
            action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}",
                                    type="open_page")
        elif recipient == "browser.find":
            action = ActionFind(pattern=browser_call["pattern"],
                                url=f"cursor:{browser_call.get('url', '')}",
                                type="find")
        else:
            raise ValueError(f"Unknown browser action: {recipient}")
        web_search_item = ResponseFunctionWebSearch(
            id=f"ws_{random_uuid()}",
            action=action,
            status="completed",
            type="web_search_call",
        )
        output_items.append(web_search_item)
    elif message.channel == "analysis":
        for content in message.content:
            reasoning_item = ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                summary=[],
                type="reasoning",
                content=[Content(text=content.text, type="reasoning_text")],
                status=None,
            )
            output_items.append(reasoning_item)
    elif message.channel == "commentary":
        if message.recipient is None:
            pass
        elif message.recipient.startswith("functions."):
            function_name = message.recipient.split(".")[-1]
            for content in message.content:
                random_id = random_uuid()
                response_item = ResponseFunctionToolCall(
                    arguments=content.text,
                    call_id=f"call_{random_id}",
                    type="function_call",
                    name=function_name,
                    id=f"fc_{random_id}",
                )
                output_items.append(response_item)
        elif message.recipient.startswith(
                "python") or message.recipient.startswith("browser"):
            for content in message.content:
                reasoning_item = ResponseReasoningItem(
                    id=f"rs_{random_uuid()}",
                    summary=[],
                    type="reasoning",
                    content=[Content(text=content.text, type="reasoning_text")],
                    status=None,
                )
                output_items.append(reasoning_item)
        else:
            raise ValueError(f"Unknown recipient: {message.recipient}")
    elif message.channel == "final":
        contents = []
        for content in message.content:
            output_text = ResponseOutputText(
                text=content.text,
                annotations=[],  # TODO
                type="output_text",
                logprobs=None,  # TODO
            )
            contents.append(output_text)
        text_item = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
            content=contents,
            role=message.author.role,
            status="completed",
            type="message",
        )
        output_items.append(text_item)
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items
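

# Editor's note: an illustrative, self-contained sketch (not called anywhere
# in this module) of the decode path above: completion tokens are parsed back
# into Harmony messages, and each message is then mapped to Responses API
# output items.
def _decode_output_example(
        completion_tokens: list[int]) -> list[ResponseOutputItem]:
    items: list[ResponseOutputItem] = []
    for msg in parse_output_tokens(completion_tokens):
        items.extend(parse_output_message(msg))
    return items

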
def finish_reason_mapping(finish_reason: str) -> str:
    match finish_reason:
        case 'stop':
            return 'completed'
        case 'length':
            return 'incomplete'
        case 'timeout':
            return 'failed'
        case 'cancelled':
            return 'cancelled'

    raise RuntimeError("Should never reach here!")
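
# Editor's note (illustrative): finish_reason_mapping translates LLM finish
# reasons into Responses API statuses, e.g. finish_reason_mapping("stop") ==
# "completed" and finish_reason_mapping("length") == "incomplete"; any other
# finish reason raises RuntimeError.

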
async def request_preprocess(request: ResponsesRequest,
                             prev_response: Optional[ResponsesResponse],
                             harmony_adapter: HarmonyAdapter,
                             conversation_store: ConversationHistoryStore,
                             enable_store=False):
    # TODO: fix default_max_tokens
    sampling_params = request.to_sampling_params(
        default_max_tokens=int(16384),
        default_sampling_params={
            "stop_token_ids": harmony_adapter.get_stop_tokens()
        })

    prev_response_id = request.previous_response_id

    # TODO: better way to enable metrics
    if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
        sampling_params.return_perf_metrics = True

    prev_msgs = []
    if enable_store:
        prev_msgs = await conversation_store.get_conversation_history(
            prev_response_id)

    responses_debug_log(f"Prev msgs:")
    for msg in prev_msgs:
        responses_debug_log(f" -> {msg.to_json()}")

    messages = construct_harmony_messages(request,
                                          prev_response,
                                          prev_msgs=prev_msgs)

    if enable_store and request.store:
        # Remove reasoning messages to save token usage during multi-turn conversation
        msgs_to_store = [msg for msg in messages if msg.channel != "analysis"]
        await conversation_store.store_messages(request.request_id,
                                                msgs_to_store, prev_response_id)

    input_tokens = render_for_completion(messages)

    responses_debug_log("======= Complete Inputs to model =======")
    responses_debug_log(decode_tokens(input_tokens))
    responses_debug_log("========================================")
    return input_tokens, sampling_params


async def create_response(
    generator,
    request: ResponsesRequest,
    sampling_params,
    model_name: str,
    conversation_store: ConversationHistoryStore,
    generation_result: RequestOutput = None,
    enable_store=False,
    create_time: int = None,
) -> ResponsesResponse:

    final_res: Optional[RequestOutput] = None
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    prev_response_id = request.previous_response_id

    if generation_result is not None:
        final_res = generation_result
    else:
        final_res = await generator

    if final_res is None:
        raise RuntimeError("No output generated or provided")

    responses_debug_log("================================================")
    responses_debug_log("RAW MODEL OUTPUT:")
    responses_debug_log(final_res.outputs)
    responses_debug_log("================================================")

    output_messages = parse_output_tokens(final_res.outputs[0].token_ids)

    responses_debug_log(f"output messages: {len(output_messages)}")
    for msg in output_messages:
        responses_debug_log(f" -> {msg.to_json()}")

    # prepare responses output
    output_content = []
    for msg in output_messages:
        output_content.extend(parse_output_message(msg))

    response = ResponsesResponse.from_request(
        request=request,
        sampling_params=sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=output_content,
        status=finish_reason_mapping(final_res.outputs[0].finish_reason),
    )

    if enable_store and request.store:
        await conversation_store.store_response(resp=response,
                                                resp_msgs=output_messages,
                                                prev_resp_id=prev_response_id)

    responses_debug_log("========== Response ===========")
    responses_debug_log(response)
    responses_debug_log("===============================")
    return response


async def process_streaming_events(
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        generator,
        harmony_adapter: HarmonyAdapter,
        model_name: str,
        conversation_store: ConversationHistoryStore,
        create_time: int = None,
        enable_store=False) -> AsyncGenerator[str, None]:
    sequence_number = 0
    response_creation_time = create_time if create_time is not None else int(
        time.time())
    final_res: Optional[RequestOutput] = None

    def _send_event(event: OpenAIBaseModel):
        nonlocal sequence_number
        # Set sequence_number if the event has this attribute
        if hasattr(event, 'sequence_number'):
            event.sequence_number = sequence_number
            sequence_number += 1
        # Get event type from the event's type field if it exists
        event_type = getattr(event, 'type', 'unknown')
        return (f"event: {event_type}\n"
                f"data: {event.model_dump_json(indent=None)}\n\n")

    current_content_index = 0  # FIXME: this number is never changed
    current_output_index = 0
    current_item_id = ""  # FIXME: this value is never changed
    sent_output_item_added = False

    initial_response = ResponsesResponse.from_request(
        request,
        sampling_params,
        model_name=model_name,
        created_time=response_creation_time,
        output=[],
        status="in_progress",
        usage=None,
    ).model_dump()
    yield _send_event(
        ResponseCreatedEvent(
            type="response.created",
            sequence_number=-1,
            response=initial_response,
        ))
    yield _send_event(
        ResponseInProgressEvent(
            type="response.in_progress",
            sequence_number=-1,
            response=initial_response,
        ))

    tools = [tool.model_dump() for tool in request.tools]
    stream_request_id = f"responses-api-{request.request_id}"
    async for res in generator:
        final_res = res
        output = res.outputs[0]

        messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages(
            stream_request_id, output.token_ids_diff, tools,
            request.tool_choice)
        stream_state = harmony_adapter.get_stream_state(stream_request_id)
        assert stream_state is not None
        parser = stream_state.get_parser()

        if parser.state == StreamState.EXPECT_START:
            current_output_index += 1
            sent_output_item_added = False

            if len(messages) > 0:
                previous_item = messages[-1]
                if previous_item.recipient is not None:
                    # Deal with tool call here
                    pass
                elif previous_item.channel == "analysis":
                    reasoning_item = ResponseReasoningItem(
                        type="reasoning",
                        content=[
                            Content(
                                text=previous_item.content[0].text,
                                type="reasoning_text",
                            ),
                        ],
                        status="completed",
                        id=current_item_id,
                        summary=[],
                    )
                    yield _send_event(
                        ResponseReasoningTextDoneEvent(
                            type="response.reasoning_text.done",
                            item_id=current_item_id,
                            sequence_number=-1,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            text=previous_item.content[0].text,
                        ))
                    yield _send_event(
                        ResponseOutputItemDoneEvent(
                            type="response.output_item.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=reasoning_item,
                        ))
                elif previous_item.channel == "final":
                    text_content = ResponseOutputText(
                        type="output_text",
                        text=previous_item.content[0].text,
                        annotations=[],
                    )
                    yield _send_event(
                        ResponseTextDoneEvent(
                            type="response.output_text.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            text=previous_item.content[0].text,
                            logprobs=[],
                            item_id=current_item_id,
                        ))
                    yield _send_event(
                        ResponseContentPartDoneEvent(
                            type="response.content_part.done",
                            sequence_number=-1,
                            item_id=current_item_id,
                            output_index=current_output_index,
                            content_index=current_content_index,
                            part=text_content,
                        ))
                    yield _send_event(
                        ResponseOutputItemDoneEvent(
                            type="response.output_item.done",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseOutputMessage(
                                id=current_item_id,
                                type="message",
                                role="assistant",
                                content=[text_content],
                                status="completed",
                            ),
                        ))

        if parser.last_content_delta:
            if (parser.current_channel == "final"
                    and parser.current_recipient is None):
                if not sent_output_item_added:
                    sent_output_item_added = True
                    yield _send_event(
                        ResponseOutputItemAddedEvent(
                            type="response.output_item.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseOutputMessage(
                                id=current_item_id,
                                type="message",
                                role="assistant",
                                content=[],
                                status="in_progress",
                            ),
                        ))
                    yield _send_event(
                        ResponseContentPartAddedEvent(
                            type="response.content_part.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item_id=current_item_id,
                            content_index=current_content_index,
                            part=ResponseOutputText(
                                type="output_text",
                                text="",
                                annotations=[],
                                logprobs=[],
                            ),
                        ))
                yield _send_event(
                    ResponseTextDeltaEvent(
                        type="response.output_text.delta",
                        sequence_number=-1,
                        content_index=current_content_index,
                        output_index=current_output_index,
                        item_id=current_item_id,
                        delta=parser.last_content_delta,
                        # TODO: use logprobs from ctx.last_request_output
                        logprobs=[],
                    ))
            elif (parser.current_channel == "analysis"
                  and parser.current_recipient is None):
                if not sent_output_item_added:
                    sent_output_item_added = True
                    yield _send_event(
                        ResponseOutputItemAddedEvent(
                            type="response.output_item.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item=ResponseReasoningItem(
                                type="reasoning",
                                id=current_item_id,
                                summary=[],
                                status="in_progress",
                            ),
                        ))
                    yield _send_event(
                        ResponseContentPartAddedEvent(
                            type="response.content_part.added",
                            sequence_number=-1,
                            output_index=current_output_index,
                            item_id=current_item_id,
                            content_index=current_content_index,
                            part=ResponseOutputText(
                                type="output_text",
                                text="",
                                annotations=[],
                                logprobs=[],
                            ),
                        ))
                yield _send_event(
                    ResponseReasoningTextDeltaEvent(
                        type="response.reasoning_text.delta",
                        item_id=current_item_id,
                        output_index=current_output_index,
                        content_index=current_content_index,
                        delta=parser.last_content_delta,
                        sequence_number=-1,
                    ))

    # TODO(JunyiXu-nv): support built-in tools (python/browser/code interpreter)

    final_response = await create_response(generator, request, sampling_params,
                                           model_name, conversation_store,
                                           final_res, enable_store,
                                           response_creation_time)

    yield _send_event(
        ResponseCompletedEvent(
            type="response.completed",
            sequence_number=-1,
            response=final_response.model_dump(),
        ))
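

# Editor's note: each string yielded by process_streaming_events is one
# server-sent-events frame produced by _send_event, e.g. (illustrative):
#
#     event: response.output_text.delta
#     data: {"type": "response.output_text.delta", "sequence_number": 7, ...}
#
# with a trailing blank line, which is what separates consecutive SSE frames
# on the wire.
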
@ -1513,6 +1513,13 @@ def test_openai_chat_harmony(llm_root, llm_venv):
        str(test_root / "_test_openai_chat_harmony.py")])


def test_openai_responses(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(
        ["-m", "pytest",
         str(test_root / "_test_openai_responses.py")])


def test_openai_prometheus(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(

@ -104,6 +104,7 @@ l0_h100:
  - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
  - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
  - test_e2e.py::test_openai_chat_harmony
  - test_e2e.py::test_openai_responses
  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
  # ------------- AutoDeploy tests ---------------
  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype

@ -14,10 +14,18 @@ def model():
    return "gpt_oss/gpt-oss-20b/"


@pytest.fixture(scope="module",
                params=[0, 2],
                ids=["disable_processpool", "enable_processpool"])
def num_postprocess_workers(request):
    return request.param


@pytest.fixture(scope="module")
def server(model: str):
def server(model: str, num_postprocess_workers: int):
    model_path = get_model_path(model)
    with RemoteOpenAIServer(model_path) as remote_server:
    args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
    with RemoteOpenAIServer(model_path, args) as remote_server:
        yield remote_server
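
The num_postprocess_workers fixture above is forwarded to the spawned server as extra CLI arguments through RemoteOpenAIServer; outside the test harness the equivalent launch would look roughly like this (model path illustrative):

    trtllm-serve gpt_oss/gpt-oss-20b --num_postprocess_workers 2
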
@ -147,6 +155,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
    collected_chunks = []
    collected_messages = []
    async for chunk in response:
        # The last streaming response will only contain usage info
        if len(chunk.choices) <= 0:
            continue

        collected_chunks.append(chunk)
        collected_messages.append(chunk.choices[0].delta)

@ -198,6 +210,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
    reasoning_chunks: list[str] = []
    tool_arg_chunks: list[str] = []
    async for chunk in response:
        # The last streaming response will only contain usage info
        if len(chunk.choices) <= 0:
            continue

        delta = chunk.choices[0].delta
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            function = delta.tool_calls[0].function

241 tests/unittest/llmapi/apps/_test_openai_responses.py Normal file
@ -0,0 +1,241 @@
import json

import openai
import pytest
from openai.types.responses import (ResponseCompletedEvent,
                                    ResponseReasoningTextDeltaEvent,
                                    ResponseTextDeltaEvent)

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer

pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
def model():
    return "gpt_oss/gpt-oss-20b/"


@pytest.fixture(scope="module")
def server(model: str):
    model_path = get_model_path(model)
    with RemoteOpenAIServer(model_path) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    return server.get_async_client()


def check_response(response, prefix=""):
    reasoning_exist, message_exist = False, False
    for output in response.output:
        if output.type == "reasoning":
            reasoning_exist = True
        elif output.type == "message":
            message_exist = True

    assert reasoning_exist, f"{prefix}Reasoning content does not exist!"
    assert message_exist, f"{prefix}Message content does not exist!"


def check_tool_calling(response, first_resp=True, prefix=""):
    reasoning_exist, tool_call_exist, message_exist = False, False, False
    function_call = None
    for output in response.output:
        if output.type == "reasoning":
            reasoning_exist = True
        elif output.type == "function_call":
            tool_call_exist = True
            function_call = output
        elif output.type == "message":
            message_exist = True

    if first_resp:
        assert reasoning_exist and tool_call_exist, f"{prefix}Invalid tool calling 1st response"
        assert not message_exist, f"{prefix}Invalid tool calling 1st response"

        return function_call
    else:
        assert reasoning_exist and message_exist, f"{prefix}Invalid tool calling 2nd response"
        assert not tool_call_exist, f"{prefix}Invalid tool calling 2nd response"


@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning(client: openai.AsyncOpenAI, model: str):
    response = await client.responses.create(
        model=model, input="Which one is larger as numeric, 9.9 or 9.11?")

    check_response(response, "test_reasoning: ")


@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
    for effort in ["low", "medium", "high"]:
        response = await client.responses.create(
            model=model,
            instructions="Use less than 1024 tokens for reasoning",
            input="Which one is larger as numeric, 9.9 or 9.11?",
            reasoning={"effort": effort})
        check_response(response, f"test_reasoning_effort_{effort}: ")


@pytest.mark.asyncio(loop_scope="module")
async def test_chat(client: openai.AsyncOpenAI, model: str):
    response = await client.responses.create(model=model,
                                             input=[{
                                                 "role": "developer",
                                                 "content": "Respond in Chinese."
                                             }, {
                                                 "role": "user",
                                                 "content": "Hello!"
                                             }, {
                                                 "role": "assistant",
                                                 "content": "Hello! How can I help you?"
                                             }, {
                                                 "role": "user",
                                                 "content": "Tell me a joke."
                                             }])
    check_response(response, "test_chat: ")


@pytest.mark.asyncio(loop_scope="module")
async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str):
    response = await client.responses.create(model=model,
                                             input="What is the answer to 1+1?")
    check_response(response, "test_multi_turn_chat_1: ")

    response_2 = await client.responses.create(
        model=model,
        input="What is the answer to the previous question?",
        previous_response_id=response.id)
    check_response(response_2, "test_multi_turn_chat_2: ")


def get_current_weather(location: str, format: str = "celsius") -> dict:
    return {"sunny": True, "temperature": 20 if format == "celsius" else 68}


@pytest.mark.asyncio(loop_scope="module")
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
    tool_get_current_weather = {
        "type": "function",
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "description": "default: celsius",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        }
    }
    messages = [{"role": "user", "content": "What is the weather like in SF?"}]
    response = await client.responses.create(
        model=model,
        input=messages,
        tools=[tool_get_current_weather],
    )
    messages.extend(response.output)
    function_call = check_tool_calling(response, True, "test_tool_calls: ")

    assert function_call.name == "get_current_weather"

    args = json.loads(function_call.arguments)
    answer = get_current_weather(**args)
    messages.append({
        "type": "function_call_output",
        "call_id": function_call.call_id,
        "output": json.dumps(answer),
    })

    response = await client.responses.create(model=model,
                                             input=messages,
                                             tools=[tool_get_current_weather])

    check_tool_calling(response, False, "test_tool_calls: ")


@pytest.mark.asyncio(loop_scope="module")
async def test_streaming(client: openai.AsyncOpenAI, model: str):
    stream = await client.responses.create(
        model=model,
        input="Explain the theory of relativity in brief.",
        stream=True,
    )

    reasoning_deltas, message_deltas = list(), list()
    async for event in stream:
        if isinstance(event, ResponseTextDeltaEvent):
            message_deltas.append(event.delta)
        elif isinstance(event, ResponseReasoningTextDeltaEvent):
            reasoning_deltas.append(event.delta)

    full_response = "".join(message_deltas)
    full_reasoning_response = "".join(reasoning_deltas)
    assert full_response
    assert full_reasoning_response


@pytest.mark.asyncio(loop_scope="module")
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
    tool_get_current_weather = {
        "type": "function",
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "description": "default: celsius",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        }
    }
    messages = [{"role": "user", "content": "What is the weather like in SF?"}]
    stream = await client.responses.create(
        model=model,
        input=messages,
        tools=[tool_get_current_weather],
        stream=True,
    )

    function_call = None
    reasoning_deltas = list()
    async for event in stream:
        if isinstance(event, ResponseCompletedEvent):
            for output in event.response.output:
                if output.type == "function_call":
                    function_call = output
        elif isinstance(event, ResponseReasoningTextDeltaEvent):
            reasoning_deltas.append(event.delta)

    reasoning = "".join(reasoning_deltas)
    tool_args = json.loads(function_call.arguments)

    assert function_call.name == "get_current_weather", "wrong function call name"
    assert tool_args, "tool args do not exist!"
    assert reasoning, "reasoning does not exist!"

    get_current_weather(**tool_args)