[None][feat] Cherry-pick Responses API and multiple postprocess workers support for chat harmony (#7600)

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Co-authored-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Co-authored-by: Tao Li @ NVIDIA <tali@nvidia.com>
JunyiXu-nv 2025-09-09 19:28:29 +08:00 committed by GitHub
parent d60dad6b9d
commit ac0df0a393
9 changed files with 1764 additions and 136 deletions

tensorrt_llm/serve/harmony_adapter.py

@ -6,7 +6,7 @@ import re
import time
import traceback
import uuid
from typing import Any, AsyncGenerator, Literal
from typing import Any, List, Literal
from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, HarmonyError, Message,
@ -14,15 +14,15 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)
from tensorrt_llm.llmapi import RequestOutput
from tensorrt_llm.logger import logger
# yapf: disable
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)
@ -57,7 +57,8 @@ class HarmonyStreamState:
# Normal case: filter based on available tools
self.should_filter_tools = True
self.available_tools = {
tool.get("function", {}).get("name", "")
tool.get("function", {}).get("name", "") if tool.get(
"name", None) is None else tool.get("name")
for tool in available_tools
}
self.available_tools.discard("")
@ -78,6 +79,9 @@ class HarmonyStreamState:
logger.debug("Created HarmonyStreamState for request %s", request_id)
def get_parser(self) -> StreamableParser:
return self.parser
def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
"""
Process a batch of tokens while maintaining parsing state.
@ -125,6 +129,42 @@ class HarmonyStreamState:
return deltas
def process_token_batch_to_messages(self,
tokens: list[int]) -> list[Message]:
"""
Process a batch of tokens while maintaining parsing state.
Returns OpenAI Messages for Responses API
"""
self.tokens_processed += len(tokens)
for token in tokens:
# Store previous state for transition detection
prev_channel = self.parser.current_channel
prev_recipient = self.parser.current_recipient
# Process the token
self.parser.process(token)
# Detect channel/recipient transitions AFTER processing each token
channel_changed = prev_channel != self.parser.current_channel
recipient_changed = prev_recipient != self.parser.current_recipient
if channel_changed or recipient_changed:
# Mark any active tool calls as completed if we're leaving a tool call
if prev_channel == "commentary" and prev_recipient and "functions." in str(
prev_recipient):
func_name = str(prev_recipient).split("functions.")[-1]
for tool_id, tool_info in self.tool_calls.items():
if tool_info["name"] == func_name and tool_info.get(
"active", True):
tool_info["active"] = False
# Reset channel state for new channel
self.channel_started = False
self.current_channel_state = None
return self.parser.messages
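The parser returned by get_parser() keeps channel and recipient state across batches, so a caller can watch the model move between the "analysis", "commentary" and "final" channels. A minimal sketch, assuming a stream state obtained from HarmonyAdapter.create_stream_state; only the helper name channels_seen is invented:

def channels_seen(stream_state: HarmonyStreamState,
                  token_batches: list[list[int]]) -> list[str | None]:
    """Record which Harmony channel the parser is in after each token batch."""
    parser = stream_state.get_parser()
    seen = []
    for batch in token_batches:
        stream_state.process_token_batch_to_messages(batch)
        seen.append(parser.current_channel)  # e.g. "analysis", "commentary", "final"
    return seen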
def _create_closing_token_delta(self) -> dict[str, Any] | None:
"""Create closing token delta for channel transition."""
if not self.current_channel_state or not self.channel_started:
@ -317,6 +357,9 @@ class HarmonyAdapter:
"<|constrain|>": 200009,
}
def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
return self._stream_states.get(request_id, None)
def get_stop_tokens(self) -> list[int]:
"""
Return the list of stop token IDs for Harmony format.
@ -1214,6 +1257,42 @@ class HarmonyAdapter:
# Return empty deltas to continue processing
return []
def stateful_stream_harmony_tokens_to_openai_messages(
self,
request_id: str,
tokens: list[int],
available_tools: list[dict[str, Any]] | None = None,
tool_choice: str | None = None) -> list[Message]:
"""
Process tokens using stateful parsing.
This method maintains persistent state across multiple calls for the same request,
ensuring proper channel transitions and tool call handling.
Args:
request_id: Request ID to maintain state per request
tokens: New tokens from this iteration
available_tools: Available tools for filtering
Returns:
List of OpenAI Messages
"""
stream_state = self._stream_states.get(request_id, None)
if stream_state is None:
stream_state = self.create_stream_state(request_id, available_tools,
tool_choice)
try:
messages = stream_state.process_token_batch_to_messages(tokens)
return messages
except (HarmonyError, UnicodeDecodeError, ValueError):
logger.error(
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
)
logger.debug(f"Problematic streaming tokens: {tokens}")
return []
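A hedged driver sketch for the method above; collect_stream_messages and the outputs iterable are illustrative stand-ins for the server-side streaming loop, all other names come from this file:

def collect_stream_messages(adapter: HarmonyAdapter, request_id: str,
                            outputs) -> list[Message]:
    """Feed each iteration's new tokens; the adapter keeps per-request parser state."""
    messages: list[Message] = []
    for output in outputs:
        messages = adapter.stateful_stream_harmony_tokens_to_openai_messages(
            request_id,
            output.token_ids_diff,  # only the tokens generated in this iteration
            available_tools=None,
            tool_choice=None,
        )
    adapter.cleanup_stream_state(request_id)  # drop per-request state when done
    return messages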
def create_openai_streaming_response(
self,
request_id: str,
@ -1406,36 +1485,72 @@ class HarmonyAdapter:
return True
async def handle_streaming_response(
harmony_adapter: HarmonyAdapter,
generator: RequestOutput,
request_id: str,
request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
"""Handle streaming response with harmony format."""
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
def get_harmony_adapter():
global _SERVE_HARMONY_ADAPTER
if _SERVE_HARMONY_ADAPTER is None:
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()
return _SERVE_HARMONY_ADAPTER
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
request_id: str, done: bool,
num_prompt_tokens: int):
first_iteration = True
async for res in generator:
output = res.outputs[0]
output = outputs[0]
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]
# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
# Get tool_choice from request - if "none", don't pass tools to parser
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
# Create OpenAI streaming responses
try:
res = []
if done:
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
# Send final message with finish_reason
final_response = ChatCompletionStreamResponse(
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
],
)
final_response_json = final_response.model_dump_json(
exclude_none=True)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=model,
usage=usage_info)
final_usage_json = final_usage_chunk.model_dump_json(
exclude_none=True)
res.append(f"data: {final_response_json}\n\n")
res.append(f"data: {final_usage_json}\n\n")
else:
tools_for_parser = tools_dict
# Create OpenAI streaming responses
try:
responses = harmony_adapter.create_openai_streaming_response(
request_id=request_id,
tokens=output.token_ids_diff,
available_tools=tools_for_parser,
model_name=request.model,
model_name=model,
tool_choice=tool_choice)
# Send first response after receiving the first output
if first_iteration:
@ -1446,64 +1561,44 @@ async def handle_streaming_response(
delta=first_delta)
first_response = ChatCompletionStreamResponse(
model=request.model,
model=model,
choices=[choice],
)
response_json = first_response.model_dump_json(
exclude_none=True)
yield f"data: {response_json}\n\n"
res.append(f"data: {response_json}\n\n")
for response in responses:
yield response
res.extend(responses)
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
return res
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
# Send final message with finish_reason
output = generator.outputs[0]
final_response = ChatCompletionStreamResponse(
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
])
yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e
async def handle_non_streaming_response(
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
request: ChatCompletionRequest) -> ChatCompletionResponse:
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
num_prompt_tokens: int):
"""Handle non-streaming response with harmony format."""
# Get final result
await promise
# Parse harmony output to OpenAI format
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]
# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
output = promise.outputs[0]
output = outputs[0]
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)
@ -1518,11 +1613,11 @@ async def handle_non_streaming_response(
output.finish_reason)
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(promise)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
# Create response
response = ChatCompletionResponse(
model=request.model,
model=model,
choices=[
ChatCompletionResponseChoice(
index=0,
@ -1534,7 +1629,6 @@ async def handle_non_streaming_response(
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
logger.debug(f"request\n\n{request}")
logger.debug(f"response\n\n{response}\n")
return response
@ -1567,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
return reason
def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
"""Create usage info from RequestOutput following serving_chat.py pattern."""
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
# Calculate completion tokens from all outputs
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in outputs)
# Create usage info
usage = UsageInfo(prompt_tokens=num_prompt_tokens,

tensorrt_llm/serve/openai_protocol.py

@ -11,9 +11,16 @@ from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import \
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from openai.types.responses import (ResponseFunctionToolCall,
ResponseInputItemParam, ResponseOutputItem,
ResponsePrompt, ResponseReasoningItem,
ResponseStatus, ResponseTextConfig)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from typing_extensions import Annotated, Required, TypeAlias, TypedDict
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
@ -665,6 +672,208 @@ class ChatCompletionRequest(OpenAIBaseModel):
return data
ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
ResponseFunctionToolCall]
class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/responses/create
background: Optional[bool] = False
include: Optional[list[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
],
]] = None
input: Union[str, list[ResponseInputOutputItem]]
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
metadata: Optional[Metadata] = None
model: str
parallel_tool_calls: Optional[bool] = False
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale",
"priority"] = "auto"
store: Optional[bool] = True
stream: Optional[bool] = False
temperature: Optional[float] = None
text: Optional[ResponseTextConfig] = None
tool_choice: ToolChoice = "auto"
tools: list[Tool] = Field(default_factory=list)
top_logprobs: Optional[int] = 0
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None
request_id: str = Field(
default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."),
)
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
}
def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
if self.max_output_tokens is None:
max_tokens = default_max_tokens
else:
max_tokens = min(self.max_output_tokens, default_max_tokens)
default_sampling_params = default_sampling_params or {}
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output
guided_decoding = None
if self.text is not None and self.text.format is not None:
response_format = self.text.format
if response_format.type == "json_schema":
guided_decoding = GuidedDecodingParams(
json=response_format.schema_)
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")
return SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
logprobs=self.top_logprobs,
stop_token_ids=stop_token_ids,
guided_decoding=guided_decoding,
)
@model_validator(mode="before")
@classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
return data
@model_validator(mode="before")
@classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise ValueError("prompt template is not supported")
return data
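For illustration, a minimal request converts to sampling parameters as follows; the model name and the 16384-token budget are placeholders, and the defaults in the comment come from _DEFAULT_SAMPLING_PARAMS above:

from tensorrt_llm.serve.openai_protocol import ResponsesRequest

req = ResponsesRequest(input="What is the capital of France?", model="gpt-oss-120b")
params = req.to_sampling_params(default_max_tokens=16384)
# Unset fields fall back to temperature=1.0 and top_p=1.0, and max_tokens is
# capped at default_max_tokens.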
class InputTokensDetails(OpenAIBaseModel):
cached_tokens: int
class OutputTokensDetails(OpenAIBaseModel):
reasoning_tokens: int
class ResponseUsage(OpenAIBaseModel):
input_tokens: int
input_tokens_details: InputTokensDetails
output_tokens: int
output_tokens_details: OutputTokensDetails
total_tokens: int
class ResponsesResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}")
created_at: int = Field(default_factory=lambda: int(time.time()))
# error: Optional[ResponseError] = None
# incomplete_details: Optional[IncompleteDetails] = None
instructions: Optional[str] = None
metadata: Optional[Metadata] = None
model: str
object: Literal["response"] = "response"
output: list[ResponseOutputItem]
parallel_tool_calls: bool
temperature: float
tool_choice: ToolChoice
tools: list[Tool]
top_p: float
background: bool
max_output_tokens: int
max_tool_calls: Optional[int] = None
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
status: ResponseStatus
text: Optional[ResponseTextConfig] = None
top_logprobs: int
truncation: Literal["auto", "disabled"]
usage: Optional[ResponseUsage] = None
user: Optional[str] = None
@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
created_time: int,
output: list[ResponseOutputItem],
status: ResponseStatus,
usage: Optional[ResponseUsage] = None,
) -> "ResponsesResponse":
return cls(
id=request.request_id,
created_at=created_time,
instructions=request.instructions,
metadata=request.metadata,
model=model_name,
output=output,
parallel_tool_calls=request.parallel_tool_calls,
temperature=sampling_params.temperature,
tool_choice=request.tool_choice,
tools=request.tools,
top_p=sampling_params.top_p,
background=request.background,
max_output_tokens=sampling_params.max_tokens,
max_tool_calls=request.max_tool_calls,
previous_response_id=request.previous_response_id,
prompt=request.prompt,
reasoning=request.reasoning,
service_tier=request.service_tier,
status=status,
text=request.text,
top_logprobs=sampling_params.logprobs,
truncation=request.truncation,
user=request.user,
usage=usage,
)
class ResponsesStreamResponse(OpenAIBaseModel):
response: ResponsesResponse
sequence_number: int
type: Literal["response.created", "response.in_progress",
"response.completed", "response.failed",
"response.incomplete"]
def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
if opaque_state is None:
return None

tensorrt_llm/serve/openai_server.py

@ -41,17 +41,25 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
CompletionResponse,
CompletionResponseChoice,
ErrorResponse, ModelCard,
ModelList, UsageInfo,
ModelList, ResponsesRequest,
UsageInfo,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
chat_stream_post_processor, completion_response_post_processor,
completion_stream_post_processor)
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
chat_response_post_processor, chat_stream_post_processor,
completion_response_post_processor, completion_stream_post_processor)
from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
from tensorrt_llm.serve.responses_utils import \
create_response as responses_api_create_response
from tensorrt_llm.serve.responses_utils import \
process_streaming_events as responses_api_process_streaming_events
from tensorrt_llm.serve.responses_utils import \
request_preprocess as responses_api_request_preprocess
from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
handle_streaming_response,
from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
maybe_transform_reasoning_effort)
# yapf: enable
@ -83,6 +91,12 @@ class OpenAIServer:
logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
self.model_config = None
# Enable response storage for Responses API
self.enable_store = True
if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
self.enable_store = False
self.conversation_store = ConversationHistoryStore()
model_dir = Path(model)
if model_dir.exists() and model_dir.is_dir():
self.model = model_dir.name
@ -104,7 +118,11 @@ class OpenAIServer:
# gpt-oss
self.harmony_adapter: HarmonyAdapter | None = None
self.use_harmony = self.model_config.model_type == "gpt_oss"
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
if disable_harmony:
self.use_harmony = False
else:
self.use_harmony = (self.model_config.model_type == "gpt_oss")
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -166,6 +184,20 @@ class OpenAIServer:
return JSONResponse(content=error_response.model_dump(),
status_code=error_response.code)
def _create_invalid_response_id_error(self, response_id: str) -> Response:
return self.create_error_response(
err_type="InvalidRequestError",
message=(f"Invalid 'response_id': '{response_id}'. "
"Expected an ID that begins with 'resp'."),
)
def _create_response_id_not_found_error(self, response_id: str) -> Response:
return self.create_error_response(
err_type="InvalidRequestError",
message=f"Response with id '{response_id}' not found.",
status_code=HTTPStatus.NOT_FOUND,
)
def register_routes(self):
self.app.add_api_route("/health", self.health, methods=["GET"])
self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
@ -182,6 +214,9 @@ class OpenAIServer:
self.app.add_api_route("/v1/chat/completions",
self.openai_chat if not self.use_harmony else self.chat_harmony,
methods=["POST"])
self.app.add_api_route("/v1/responses",
self.openai_responses,
methods=["POST"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
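Once registered, the new route can be exercised with the official OpenAI client; the base URL, port, API key and model name below are assumptions, not values fixed by this change:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = client.responses.create(model="gpt-oss-120b", input="Say hello in one word.")
print(resp.output_text)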
@ -681,11 +716,35 @@ class OpenAIServer:
Chat Completion API with harmony format support.
Supports both streaming and non-streaming modes.
"""
async def create_harmony_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response = promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)
return chat_response
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
if not self.postproc_worker_enabled:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
# await self._extract_metrics(res)
for pp_res in pp_results:
yield pp_res
yield "data: [DONE]\n\n"
try:
# Initialize HarmonyAdapter
# NOTE: WAR for Disagg failure, may affect perf if no warmup
if not self.harmony_adapter:
self.harmony_adapter = HarmonyAdapter()
self.harmony_adapter = get_harmony_adapter()
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
tools_dict = None
if request.tools:
@ -720,27 +779,37 @@ class OpenAIServer:
vocab_size=self.tokenizer.tokenizer.vocab_size)
sampling_params.detokenize = False # Harmony adapter handles detokenization
postproc_args = ChatCompletionPostprocArgs.from_request(request)
postproc_params = PostprocParams(
post_processor=chat_harmony_streaming_post_processor
if request.stream else chat_harmony_post_processor,
postproc_args=postproc_args,
)
# Generate
promise = self.llm.generate_async(
inputs=harmony_tokens,
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=bool(request.stream),
lora_request=request.lora_request,
)
postproc_args.request_id = promise.request_id
if not self.postproc_worker_enabled:
postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
# Disconnect cancellation
asyncio.create_task(self.await_disconnected(raw_request, promise))
# Handle streaming
if request.stream:
return StreamingResponse(
handle_streaming_response(
self.harmony_adapter, promise,
str(promise.request_id), request,
),
content=create_streaming_generator(promise, postproc_params),
media_type="text/event-stream"
)
else:
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
response = await create_harmony_response(promise, postproc_params)
return JSONResponse(response.model_dump())
except Exception as e:
@ -748,6 +817,80 @@ class OpenAIServer:
logger.debug("Error details: %s", traceback.format_exc())
return self.create_error_response(message=str(e), err_type="internal_error")
async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
async for event_data in responses_api_process_streaming_events(
request=request,
sampling_params=sampling_params,
generator=generator,
harmony_adapter=self.harmony_adapter,
model_name=self.model,
conversation_store=self.conversation_store,
enable_store=self.enable_store
):
yield event_data
try:
if not self.use_harmony:
raise NotImplementedError("Responses API only supports harmony format for now")
# Initialize HarmonyAdapter
# NOTE: WAR for Disagg failure, may affect perf if no warmup
if not self.harmony_adapter:
self.harmony_adapter = HarmonyAdapter()
if request.background:
logger.warning("Request.background is not supported yet, will fallback to foreground processing.")
# Get prev response
prev_response = None
if self.enable_store:
prev_response_id = request.previous_response_id
if prev_response_id is not None:
if not prev_response_id.startswith("resp_"):
return self._create_invalid_response_id_error(prev_response_id)
prev_response = await self.conversation_store.load_response(prev_response_id)
if prev_response is None:
logger.debug(f"response_id {prev_response_id} not found")
return self._create_response_id_not_found_error(prev_response_id)
input_tokens, sampling_params = await responses_api_request_preprocess(
request, prev_response, self.harmony_adapter, self.conversation_store, self.enable_store)
promise = self.llm.generate_async(
inputs=input_tokens,
sampling_params=sampling_params,
streaming=request.stream,
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
if request.stream:
return StreamingResponse(
create_stream_response(promise, request, sampling_params),
media_type="text/event-stream"
)
else:
return await responses_api_create_response(
generator=promise,
request=request,
sampling_params=sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store)
except CppExecutorError:
logger.error(traceback.format_exc())
# If internal executor error is raised, shutdown the server
signal.raise_signal(signal.SIGINT)
except Exception as e:
logger.error(traceback.format_exc())
return self.create_error_response(str(e))
return JSONResponse(content={"detail": "None"})
async def __call__(self, host, port):
# Store the binding address for server registration
self.binding_addr = f"http://{host}:{port}"

tensorrt_llm/serve/postprocess_handlers.py

@ -9,6 +9,8 @@ from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
from .harmony_adapter import (handle_non_streaming_response,
handle_streaming_response)
from .openai_protocol import (ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ChatCompletionNamedToolChoiceParam,
@ -24,7 +26,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
FunctionCall, StreamOptions, ToolCall, UsageInfo,
to_disaggregated_params)
# yapf: enale
# yapf: enable
@dataclass(kw_only=True)
class ChatPostprocArgs(PostprocArgs):
@ -57,8 +60,7 @@ class ChatPostprocArgs(PostprocArgs):
)
def create_logprobs(token_ids: List[int],
tokenizer: TransformersTokenizer,
def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
logprobs: List[float]) -> ChatCompletionLogProbs:
assert len(token_ids) == len(logprobs), \
"token_ids and logprobs have different lengths"
@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
return chat_logprobs
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[bool, str, str]:
reasoning_parser = None
if args.reasoning_parser is not None:
if output_index not in args.reasoning_parser_dict:
args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
args.reasoning_parser)
args.reasoning_parser_dict[
output_index] = ReasoningParserFactory.create_reasoning_parser(
args.reasoning_parser)
reasoning_parser = args.reasoning_parser_dict[output_index]
in_reasoning = False
@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
@nvtx_range_debug("chat_stream_post_processor")
def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
def chat_stream_post_processor(rsp: GenerationResultBase,
args: ChatPostprocArgs) -> List[str]:
def yield_first_chat(num_tokens: int,
idx: int,
@ -128,9 +133,13 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
include_continuous_usage = False
if args.first_iteration:
for i in range(args.num_choices):
res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
res.append(
f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
)
if args.echo and args.last_message_content:
res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
res.append(
f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
)
args.first_iteration = False
for output in rsp.outputs:
@ -158,14 +167,18 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
delta_message = DeltaMessage(
content=delta_text, reasoning_content=reasoning_delta_text)
choice = ChatCompletionResponseStreamChoice(index=i,
delta=delta_message,
finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
finish_reason=None,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None))
if args.return_logprobs:
logprobs = output.logprobs_diff
token_ids = output.token_ids_diff
choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
choice.logprobs = create_logprobs(token_ids, args.tokenizer,
logprobs)
if output.finish_reason is not None:
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
@ -179,57 +192,62 @@ def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs
res.append(f"data: {data}\n\n")
if include_usage and rsp._done:
completion_tokens = sum(output.length
for output in rsp.outputs)
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
choices=[], model=args.model, usage=final_usage)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=args.model,
usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
return res
@nvtx_range_debug("chat_response_post_processor")
def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
def chat_response_post_processor(
rsp: GenerationResultBase,
args: ChatPostprocArgs) -> ChatCompletionResponse:
choices: List[ChatCompletionResponseChoice] = []
role = args.role
for output in rsp.outputs:
_, text, reasoning_text = apply_reasoning_parser(
args, output.index, output.text, False)
if args.tool_choice and isinstance(
args.tool_choice,
ChatCompletionNamedToolChoiceParam):
if args.tool_choice and isinstance(args.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=args.tool_choice.function.name,
arguments=text))
name=args.tool_choice.function.name, arguments=text))
])
else:
if text is None:
text = ""
message = ChatMessage(
role=role, content=text, reasoning_content=reasoning_text)
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
message = ChatMessage(role=role,
content=text,
reasoning_content=reasoning_text)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = ChatCompletionResponseChoice(
index=output.index,
message=message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
disaggregated_params=disaggregated_params,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
if args.return_logprobs:
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
output.logprobs)
choices.append(choice)
if args.echo and args.last_message_content:
@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
choice.message.content = full_message
num_prompt_tokens = args.num_prompt_tokens
num_generated_tokens = sum(
len(output.token_ids) for output in rsp.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@ -275,7 +292,8 @@ class CompletionPostprocArgs(PostprocArgs):
@nvtx_range_debug("completion_stream_post_processor")
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
args: CompletionPostprocArgs) -> List[str]:
res: List[str] = []
prompt_tokens = args.num_prompt_tokens
if stream_option := args.stream_options:
@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
index=args.prompt_idx * args.num_choices + output.index,
text=delta_text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids_diff,
finish_reason = output.finish_reason,
stop_reason = output.stop_reason,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
chunk = CompletionStreamResponse(model=args.model, choices=[choice])
if include_continuous_usage:
@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
res.append(f"data: {data}\n\n")
if include_usage and rsp._done:
completion_tokens = sum(output.length
for output in rsp.outputs)
completion_tokens = sum(output.length for output in rsp.outputs)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
choices=[], model=args.model, usage=final_usage)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=args.model,
usage=final_usage)
final_usage_data = final_usage_chunk.model_dump_json()
res.append(f"data: {final_usage_data}\n\n")
args.first_iteration = False
@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
@nvtx_range_debug("completion_response_post_processor")
def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
def completion_response_post_processor(
rsp: GenerationResult,
args: CompletionPostprocArgs) -> CompletionResponse:
prompt_tokens = args.num_prompt_tokens
completion_tokens = 0
choices = []
@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
text = output.text
if args.echo:
text = args.prompt + text
disaggregated_params = to_disaggregated_params(output.disaggregated_params)
disaggregated_params = to_disaggregated_params(
output.disaggregated_params)
choice = CompletionResponseChoice(
text=text if args.detokenize else "",
token_ids=None if args.detokenize else output.token_ids,
index=args.prompt_idx * args.num_choices + output.index,
disaggregated_params=disaggregated_params,
context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
context_logits=None
if rsp.context_logits is None else rsp.context_logits.tolist(),
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
avg_decoded_tokens_per_iter=getattr(rsp,
'avg_decoded_tokens_per_iter',
None),
)
completion_tokens += output.length
choices.append(choice)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=completion_tokens + prompt_tokens)
response = CompletionResponse(choices=choices, model=args.model, usage=usage)
completion_tokens=completion_tokens,
total_tokens=completion_tokens + prompt_tokens)
response = CompletionResponse(choices=choices,
model=args.model,
usage=usage)
return response
@dataclass(kw_only=True)
class ChatCompletionPostprocArgs(PostprocArgs):
model: str
tools: Optional[List[ChatCompletionToolsParam]]
tool_choice: Optional[Union[Literal["none", "auto"],
ChatCompletionNamedToolChoiceParam]]
request_id: Optional[int] = None
@classmethod
def from_request(cls, request: ChatCompletionRequest):
return cls(
model=request.model,
tools=request.tools,
tool_choice=request.tool_choice,
)
@nvtx_range_debug("chat_harmony_post_processor")
def chat_harmony_post_processor(
rsp: GenerationResult,
args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
response = handle_non_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
outputs=rsp.outputs,
model=args.model,
num_prompt_tokens=args.num_prompt_tokens,
)
return response
@nvtx_range_debug("chat_harmony_streaming_post_processor")
def chat_harmony_streaming_post_processor(
rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
response = handle_streaming_response(
tools=args.tools,
tool_choice=args.tool_choice,
outputs=rsp.outputs,
model=args.model,
request_id=args.request_id,
done=rsp._done,
num_prompt_tokens=args.num_prompt_tokens,
)
return response

tensorrt_llm/serve/responses_utils.py

@ -0,0 +1,848 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import os
import time
import uuid
from collections.abc import AsyncGenerator
from copy import copy
from typing import Literal, Optional, OrderedDict, Union
# yapf: disable
from openai.types.responses import (ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionToolCall,
ResponseInProgressEvent, ResponseOutputItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent)
# yapf: enable
from openai.types.responses.response_function_web_search import (
ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch)
from openai.types.responses.response_reasoning_item import Content
from openai.types.responses.tool import Tool
from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, Message, ReasoningEffort, Role,
StreamState, SystemContent, TextContent,
ToolDescription, load_harmony_encoding)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel,
ResponseInputOutputItem,
ResponsesRequest,
ResponsesResponse)
from .harmony_adapter import HarmonyAdapter
REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
"low": ReasoningEffort.LOW,
}
ENABLE_RESPONSES_DEBUG_MSG = False
def responses_debug_log(msg):
if ENABLE_RESPONSES_DEBUG_MSG:
logger.debug(msg)
_harmony_encoding = None
def random_uuid():
return str(uuid.uuid4().hex)
def get_encoding():
global _harmony_encoding
if _harmony_encoding is None:
_harmony_encoding = load_harmony_encoding(
HarmonyEncodingName.HARMONY_GPT_OSS)
return _harmony_encoding
def decode_tokens(tokens):
return get_encoding().decode(tokens)
def parse_response_input(
input_msg: ResponseInputOutputItem,
prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]]
) -> Message:
if not isinstance(input_msg, dict):
input_msg = input_msg.model_dump()
responses_debug_log(f"------- Parsing input -----------")
responses_debug_log(input_msg)
responses_debug_log("")
if "type" not in input_msg or input_msg["type"] == "message":
role = input_msg["role"]
content = input_msg["content"]
if role == "system":
# User is trying to set a system message. Change it to:
# <|start|>developer<|message|># Instructions
# {instructions}<|end|>
role = "developer"
text_prefix = "Instructions:\n"
else:
text_prefix = ""
if isinstance(content, str):
msg = Message.from_role_and_content(role, text_prefix + content)
elif isinstance(content, list):
contents = [
TextContent(text=text_prefix + c["text"]) for c in content
]
msg = Message.from_role_and_contents(role, contents)
else:
logger.warning("Responses API: Invalid input message type")
msg = None
elif input_msg["type"] == "function_call_output":
call_id = input_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
for prev_response in reversed(prev_responses):
if isinstance(prev_response, ResponseFunctionToolCall
) and prev_response.call_id == call_id:
call_response = prev_response
break
if call_response is None:
raise ValueError(f"No call message found for {call_id}")
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{call_response.name}"),
input_msg["output"])
elif input_msg["type"] == "reasoning":
content = input_msg["content"]
assert len(content) == 1
msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
elif input_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT,
input_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{input_msg['name']}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {input_msg['type']}")
return msg
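For reference, illustrative shapes of the input items this parser recognizes; the values are invented, and a function_call_output is only resolved if a previously returned tool call with the same call_id is available:

example_items = [
    # Chat-style message; a "system" role is rewritten into a developer message.
    {"type": "message", "role": "user", "content": "What's the weather in Paris?"},
    # Output of an earlier tool call, matched back to it via call_id.
    {"type": "function_call_output", "call_id": "call_demo", "output": '{"temp_c": 21}'},
    # Reasoning item echoed back from a previous response.
    {"type": "reasoning",
     "content": [{"type": "reasoning_text", "text": "The user asked about weather."}]},
]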
class ConversationHistoryStore:
def __init__(self, resp_capacity: int = 16, max_conversations=32):
self.response_capacity = resp_capacity
self.conversation_capacity = resp_capacity * 4
self.max_conversations = max_conversations
self.responses_lock = asyncio.Lock()
self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict()
self.conversations_lock = asyncio.Lock()
self.conversations: OrderedDict[str, list[Message]] = OrderedDict()
self.response_to_conversation: dict[str, str] = {}
self.conversation_to_response: dict[str, str] = {}
async def load_response(self, resp_id: str) -> ResponsesResponse:
responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}")
async with self.responses_lock:
return self.responses.get(resp_id)
async def store_response(self,
resp: ResponsesResponse,
resp_msgs: Optional[list[Message]] = [],
prev_resp_id: Optional[str] = None) -> None:
resp_id = resp.id
responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}")
async with self.responses_lock:
self.responses[resp_id] = resp
if len(self.responses) > self.response_capacity:
self._pop_response()
async with self.conversations_lock:
conversation_id: str
if resp_id in self.response_to_conversation:
conversation_id = self.response_to_conversation[resp_id]
self.conversations[conversation_id].extend(resp_msgs)
elif prev_resp_id is not None:
conversation_id = self.response_to_conversation[prev_resp_id]
self.conversations[conversation_id].extend(resp_msgs)
while len(self.conversations[conversation_id]
) > self.conversation_capacity:
self._pop_conversation(resp_id)
else:
conversation_id = random_uuid()
self.conversations[conversation_id] = resp_msgs
responses_debug_log(
f" * storing at conversation id: {conversation_id}")
self.response_to_conversation[resp_id] = conversation_id
self.conversation_to_response[conversation_id] = resp_id
self._update_visited_conversation(conversation_id)
async def store_messages(self, resp_id: str, msgs: list[Message],
prev_resp_id: Optional[str]):
responses_debug_log(f"ConversationHistoryStore storing msg:")
for msg in msgs:
responses_debug_log(f" -> {msg.to_json()}")
async with self.conversations_lock:
conversation_id: str
if prev_resp_id is not None:
conversation_id = self.response_to_conversation[prev_resp_id]
else:
conversation_id = random_uuid()
responses_debug_log(
f" * storing at conversation: {conversation_id}")
self.conversations[conversation_id] = msgs
if len(self.conversations[conversation_id]
) > self.conversation_capacity:
self._pop_conversation(resp_id)
self.response_to_conversation[resp_id] = conversation_id
self.conversation_to_response[conversation_id] = resp_id
self._update_visited_conversation(conversation_id)
async def append_messages(self, resp_id: str, msgs: list[Message]):
responses_debug_log(f"ConversationHistoryStore appending msgs:")
for msg in msgs:
responses_debug_log(f" -> {msg.to_json()}")
async with self.conversations_lock:
assert resp_id in self.response_to_conversation
conversation_id = self.response_to_conversation[resp_id]
responses_debug_log(
f" * appending at conversation: {conversation_id}")
self.conversations[conversation_id].extend(msgs)
if len(self.conversations[conversation_id]
) > self.conversation_capacity:
self._pop_conversation(resp_id)
self._update_visited_conversation(conversation_id)
async def get_conversation_history(self, resp_id: str) -> list[Message]:
responses_debug_log(f"ConversationHistoryStore getting prev_msgs:")
responses_debug_log(f" -> prev_resp_id: {resp_id}")
async with self.conversations_lock:
if resp_id in self.response_to_conversation:
conversation_id = self.response_to_conversation[resp_id]
self._update_visited_conversation(conversation_id)
return self.conversations.get(conversation_id, [])
return []
def _update_visited_conversation(self, conversation_id) -> None:
if conversation_id not in self.conversations:
return
self.conversations.move_to_end(conversation_id)
if len(self.conversations) > self.max_conversations:
removed_id, _ = self.conversations.popitem(last=False)
responses_debug_log(
f"ConversationHistoryStore Removing conversation {removed_id}")
removed_resp_id = self.conversation_to_response[removed_id]
# The responses may have been removed due to response capacity
if removed_resp_id in self.response_to_conversation:
self.response_to_conversation.pop(removed_resp_id)
self.conversation_to_response.pop(removed_id)
def _pop_conversation(self, resp_id) -> None:
conversation_id = self.response_to_conversation.get(resp_id, None)
if conversation_id is None:
return
conversation = self.conversations[conversation_id]
first_conversation_range = []
for i, msg in enumerate(conversation):
if msg.author.role == Role.USER:
first_conversation_range.append(i)
elif msg.channel == "final":
first_conversation_range.append(i)
break
del conversation[
first_conversation_range[0]:first_conversation_range[1] + 1]
def _pop_response(self) -> None:
responses_debug_log(f"responses type: {type(self.responses)}")
resp_id, _ = self.responses.popitem(last=False)
if resp_id in self.response_to_conversation:
self.response_to_conversation.pop(resp_id)
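A usage sketch of the store; remember_and_recall is illustrative, and every accessor is a coroutine because the internal dicts are guarded by asyncio locks:

async def remember_and_recall(store: ConversationHistoryStore,
                              response: ResponsesResponse,
                              messages: list[Message]) -> list[Message]:
    # Persist the finished response together with its Harmony messages...
    await store.store_response(resp=response, resp_msgs=messages)
    # ...then rebuild the conversation history for a follow-up request.
    return await store.get_conversation_history(response.id)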
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
) -> Message:
sys_msg_content = SystemContent.new()
if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort])
if start_date:
sys_msg_content = sys_msg_content.with_conversation_start_date(
start_date)
if browser_description is not None:
sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg
def get_developer_message(instructions: Optional[str] = None,
tools: Optional[list[Tool]] = None) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
pass
elif tool.type == "function":
function_tools.append(tool)
else:
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions)
dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)
return dev_msg
def get_user_message(content: str) -> Message:
return Message.from_role_and_content(Role.USER, content)
def construct_harmony_messages(
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse],
prev_msgs: list[Message] = [],
) -> list[Message]:
"""Construct messages from request input, includes conversation history messages if exists."""
messages: list[Message] = []
if prev_response is None:
# New conversation.
reasoning_effort = (request.reasoning.effort
if request.reasoning else None)
sys_msg = get_system_message(reasoning_effort=reasoning_effort, )
messages.append(sys_msg)
dev_msg = get_developer_message(request.instructions, request.tools)
messages.append(dev_msg)
else:
messages.extend(prev_msgs)
# Append the new input.
# Responses API supports simple text inputs without chat format.
if isinstance(request.input, str):
messages.append(get_user_message(request.input))
else:
if prev_response is not None:
prev_outputs = copy(prev_response.output)
else:
prev_outputs = []
for input_msg in request.input:
msg = parse_response_input(input_msg, prev_outputs)
if msg is not None:
messages.append(msg)
# User passes in a tool call request and its output. We need
# to add the tool call request to prev_outputs so that the
# parse_response_input can find the tool call request when
# parsing the tool call output.
if isinstance(input_msg, ResponseFunctionToolCall):
prev_outputs.append(input_msg)
return messages
def render_for_completion(messages: list[Message]) -> list[int]:
conversation = Conversation.from_messages(messages)
responses_debug_log("Rendering conversation:")
responses_debug_log(conversation.to_json())
token_ids = get_encoding().render_conversation_for_completion(
conversation, Role.ASSISTANT)
return token_ids
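A small runnable illustration of the rendering path, using only the helpers above and openai_harmony; the prompt text and reasoning effort are arbitrary:

msgs = [
    get_system_message(reasoning_effort="low"),
    get_developer_message(instructions="Answer in one sentence."),
    get_user_message("Why is the sky blue?"),
]
prompt_token_ids = render_for_completion(msgs)
print(len(prompt_token_ids), decode_tokens(prompt_token_ids)[:80])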
def parse_output_tokens(tokens: list[int]) -> list[Message]:
return get_encoding().parse_messages_from_completion_tokens(
tokens, role=Role.ASSISTANT)
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
"""
Parse a Harmony message into a list of output response items.
"""
if message.author.role != "assistant":
# This is a message from a tool to the assistant (e.g., search result).
# Don't include it in the final output for now. This aligns with
# OpenAI's behavior on models like o4-mini.
return []
output_items: list[ResponseOutputItem] = []
recipient = message.recipient
if recipient is not None and recipient.startswith("browser."):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search")
elif recipient == "browser.open":
action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}",
type="open_page")
elif recipient == "browser.find":
action = ActionFind(pattern=browser_call["pattern"],
url=f"cursor:{browser_call.get('url', '')}",
type="find")
else:
raise ValueError(f"Unknown browser action: {recipient}")
web_search_item = ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
output_items.append(web_search_item)
elif message.channel == "analysis":
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[Content(text=content.text, type="reasoning_text")],
status=None,
)
output_items.append(reasoning_item)
elif message.channel == "commentary":
if message.recipient is None:
pass
elif message.recipient.startswith("functions."):
function_name = message.recipient.split(".")[-1]
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
elif message.recipient.startswith(
"python") or message.recipient.startswith("browser"):
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[Content(text=content.text, type="reasoning_text")],
status=None,
)
output_items.append(reasoning_item)
else:
raise ValueError(f"Unknown recipient: {message.recipient}")
elif message.channel == "final":
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
text_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
output_items.append(text_item)
else:
raise ValueError(f"Unknown channel: {message.channel}")
return output_items
def finish_reason_mapping(finish_reason: str) -> str:
match finish_reason:
case 'stop':
return 'completed'
case 'length':
return 'incomplete'
case 'timeout':
return 'failed'
case 'cancelled':
return 'cancelled'
raise RuntimeError("Should never reach here!")
async def request_preprocess(request: ResponsesRequest,
prev_response: Optional[ResponsesResponse],
harmony_adapter: HarmonyAdapter,
conversation_store: ConversationHistoryStore,
enable_store=False):
# TODO: fix default_max_tokens
sampling_params = request.to_sampling_params(
default_max_tokens=int(16384),
default_sampling_params={
"stop_token_ids": harmony_adapter.get_stop_tokens()
})
prev_response_id = request.previous_response_id
# TODO: better way to enable metrics
if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0:
sampling_params.return_perf_metrics = True
prev_msgs = []
if enable_store:
prev_msgs = await conversation_store.get_conversation_history(
prev_response_id)
responses_debug_log(f"Prev msgs:")
for msg in prev_msgs:
responses_debug_log(f" -> {msg.to_json()}")
messages = construct_harmony_messages(request,
prev_response,
prev_msgs=prev_msgs)
if enable_store and request.store:
# Remove reasoning messages to save token usage during multi-turn conversation
msgs_to_store = [msg for msg in messages if msg.channel != "analysis"]
await conversation_store.store_messages(request.request_id,
msgs_to_store, prev_response_id)
input_tokens = render_for_completion(messages)
responses_debug_log("======= Complete Inputs to model =======")
responses_debug_log(decode_tokens(input_tokens))
responses_debug_log("========================================")
return input_tokens, sampling_params
async def create_response(
generator,
request: ResponsesRequest,
sampling_params,
model_name: str,
conversation_store: ConversationHistoryStore,
generation_result: RequestOutput = None,
enable_store=False,
create_time: int = None,
) -> ResponsesResponse:
final_res: Optional[RequestOutput] = None
response_creation_time = create_time if create_time is not None else int(
time.time())
prev_response_id = request.previous_response_id
if generation_result is not None:
final_res = generation_result
else:
final_res = await generator
if final_res is None:
raise RuntimeError("No output generated or provided")
responses_debug_log("================================================")
responses_debug_log("RAW MODEL OUTPUT:")
responses_debug_log(final_res.outputs)
responses_debug_log("================================================")
output_messages = parse_output_tokens(final_res.outputs[0].token_ids)
responses_debug_log(f"output messages: {len(output_messages)}")
for msg in output_messages:
responses_debug_log(f" -> {msg.to_json()}")
# prepare responses output
output_content = []
for msg in output_messages:
output_content.extend(parse_output_message(msg))
response = ResponsesResponse.from_request(
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
)
if enable_store and request.store:
await conversation_store.store_response(resp=response,
resp_msgs=output_messages,
prev_resp_id=prev_response_id)
responses_debug_log("========== Response ===========")
responses_debug_log(response)
responses_debug_log("===============================")
return response
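Together, request_preprocess and create_response cover the non-streaming path: render Harmony input tokens and sampling params, let the engine generate, then parse the output tokens back into Responses API items. A minimal sketch of that wiring; the llm.generate_async call and its awaitable result are assumptions about the surrounding LLM API, not something this diff defines:

async def respond_non_streaming(llm, request: ResponsesRequest,
                                harmony_adapter: HarmonyAdapter,
                                conversation_store: ConversationHistoryStore,
                                model_name: str) -> ResponsesResponse:
    # Render Harmony input tokens and sampling params for this request.
    input_tokens, sampling_params = await request_preprocess(
        request, None, harmony_adapter, conversation_store)
    # Assumed engine entry point returning an awaitable RequestOutput.
    generator = llm.generate_async(input_tokens, sampling_params)
    # Parse the output tokens into Responses API items and build the response.
    return await create_response(generator, request, sampling_params,
                                 model_name, conversation_store)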
async def process_streaming_events(
        request: ResponsesRequest,
        sampling_params: SamplingParams,
        generator,
        harmony_adapter: HarmonyAdapter,
        model_name: str,
        conversation_store: ConversationHistoryStore,
        create_time: Optional[int] = None,
        enable_store: bool = False) -> AsyncGenerator[str, None]:
sequence_number = 0
response_creation_time = create_time if create_time is not None else int(
time.time())
final_res: Optional[RequestOutput] = None
def _send_event(event: OpenAIBaseModel):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
current_content_index = 0 # FIXME: this number is never changed
current_output_index = 0
current_item_id = "" # FIXME: this number is never changed
sent_output_item_added = False
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
))
yield _send_event(
ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
))
tools = [tool.model_dump() for tool in request.tools]
stream_request_id = f"responses-api-{request.request_id}"
async for res in generator:
final_res = res
output = res.outputs[0]
messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages(
stream_request_id, output.token_ids_diff, tools,
request.tool_choice)
stream_state = harmony_adapter.get_stream_state(stream_request_id)
assert stream_state is not None
parser = stream_state.get_parser()
if parser.state == StreamState.EXPECT_START:
current_output_index += 1
sent_output_item_added = False
if len(messages) > 0:
previous_item = messages[-1]
            if previous_item.recipient is not None:
                # TODO: stream tool-call deltas; for now tool calls only appear
                # in the final response.completed event.
                pass
elif previous_item.channel == "analysis":
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
Content(
text=previous_item.content[0].text,
type="reasoning_text",
),
],
status="completed",
id=current_item_id,
summary=[],
)
yield _send_event(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
))
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
))
elif previous_item.channel == "final":
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
annotations=[],
)
yield _send_event(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text=previous_item.content[0].text,
logprobs=[],
item_id=current_item_id,
))
yield _send_event(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=text_content,
))
yield _send_event(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[text_content],
status="completed",
),
))
if parser.last_content_delta:
if (parser.current_channel == "final"
and parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="in_progress",
),
))
yield _send_event(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
yield _send_event(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
output_index=current_output_index,
item_id=current_item_id,
delta=parser.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
))
elif (parser.current_channel == "analysis"
and parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
status="in_progress",
),
))
yield _send_event(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
))
yield _send_event(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
delta=parser.last_content_delta,
sequence_number=-1,
))
# TODO(JunyiXu-nv): support built-in tools(python/browser/code interpreter)
final_response = await create_response(generator, request, sampling_params,
model_name, conversation_store,
final_res, enable_store,
response_creation_time)
yield _send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
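Each string yielded above is one self-contained SSE frame: an "event:" line naming the event type followed by a "data:" line with the JSON payload and a blank line. A standalone sketch of a frame parser for this format (not part of the diff):

import json


def parse_sse_frame(frame: str) -> tuple[str, dict]:
    # Split one "event: <type>" / "data: <json>" frame into (type, payload).
    event_type, payload = "unknown", {}
    for line in frame.strip().splitlines():
        if line.startswith("event: "):
            event_type = line[len("event: "):]
        elif line.startswith("data: "):
            payload = json.loads(line[len("data: "):])
    return event_type, payload


# Example: collect the assistant text from a list of frames.
# text = "".join(payload["delta"]
#                for etype, payload in map(parse_sse_frame, frames)
#                if etype == "response.output_text.delta")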

View File

@ -1513,6 +1513,13 @@ def test_openai_chat_harmony(llm_root, llm_venv):
str(test_root / "_test_openai_chat_harmony.py")])
def test_openai_responses(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_responses.py")])
def test_openai_prometheus(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(

View File

@ -104,6 +104,7 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_openai_chat_harmony
- test_e2e.py::test_openai_responses
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype

View File

@ -14,10 +14,18 @@ def model():
return "gpt_oss/gpt-oss-20b/"
@pytest.fixture(scope="module",
params=[0, 2],
ids=["disable_processpool", "enable_processpool"])
def num_postprocess_workers(request):
return request.param
@pytest.fixture(scope="module")
def server(model: str):
def server(model: str, num_postprocess_workers: int):
model_path = get_model_path(model)
with RemoteOpenAIServer(model_path) as remote_server:
args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@ -147,6 +155,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
collected_chunks = []
collected_messages = []
async for chunk in response:
        # The last streaming chunk only contains usage info
if len(chunk.choices) <= 0:
continue
collected_chunks.append(chunk)
collected_messages.append(chunk.choices[0].delta)
@ -198,6 +210,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
reasoning_chunks: list[str] = []
tool_arg_chunks: list[str] = []
async for chunk in response:
        # The last streaming chunk only contains usage info
if len(chunk.choices) <= 0:
continue
delta = chunk.choices[0].delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
function = delta.tool_calls[0].function

View File

@ -0,0 +1,241 @@
import json
import openai
import pytest
from openai.types.responses import (ResponseCompletedEvent,
ResponseReasoningTextDeltaEvent,
ResponseTextDeltaEvent)
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
pytestmark = pytest.mark.threadleak(enabled=False)
@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
def model():
return "gpt_oss/gpt-oss-20b/"
@pytest.fixture(scope="module")
def server(model: str):
model_path = get_model_path(model)
with RemoteOpenAIServer(model_path) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
return server.get_async_client()
def check_response(response, prefix=""):
reasoning_exist, message_exist = False, False
for output in response.output:
if output.type == "reasoning":
reasoning_exist = True
elif output.type == "message":
message_exist = True
    assert reasoning_exist, f"{prefix}Reasoning content does not exist!"
    assert message_exist, f"{prefix}Message content does not exist!"
def check_tool_calling(response, first_resp=True, prefix=""):
reasoning_exist, tool_call_exist, message_exist = False, False, False
function_call = None
for output in response.output:
if output.type == "reasoning":
reasoning_exist = True
elif output.type == "function_call":
tool_call_exist = True
function_call = output
elif output.type == "message":
message_exist = True
if first_resp:
assert reasoning_exist and tool_call_exist, f"{prefix}Invalid tool calling 1st response"
assert not message_exist, f"{prefix}Invalid tool calling 1st response"
return function_call
else:
assert reasoning_exist and message_exist, f"{prefix}Invalid tool calling 2nd response"
assert not tool_call_exist, f"{prefix}Invalid tool calling 2nd response"
@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning(client: openai.AsyncOpenAI, model: str):
response = await client.responses.create(
model=model, input="Which one is larger as numeric, 9.9 or 9.11?")
check_reponse(response, "test_reasoning: ")
@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
for effort in ["low", "medium", "high"]:
response = await client.responses.create(
model=model,
instructions="Use less than 1024 tokens for reasoning",
input="Which one is larger as numeric, 9.9 or 9.11?",
reasoning={"effort": effort})
check_reponse(response, f"test_reasoning_effort_{effort}: ")
@pytest.mark.asyncio(loop_scope="module")
async def test_chat(client: openai.AsyncOpenAI, model: str):
response = await client.responses.create(model=model,
input=[{
"role":
"developer",
"content":
"Respond in Chinese."
}, {
"role": "user",
"content": "Hello!"
}, {
"role":
"assistant",
"content":
"Hello! How can I help you?"
}, {
"role": "user",
"content": "Tell me a joke."
}])
check_reponse(response, "test_chat: ")
@pytest.mark.asyncio(loop_scope="module")
async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str):
response = await client.responses.create(model=model,
input="What is the answer of 1+1?")
check_reponse(response, "test_multi_turn_chat_1: ")
response_2 = await client.responses.create(
model=model,
input="What is the answer of previous question?",
previous_response_id=response.id)
check_reponse(response_2, "test_multi_turn_chat_2: ")
def get_current_weather(location: str, format: str = "celsius") -> dict:
return {"sunny": True, "temperature": 20 if format == "celsius" else 68}
@pytest.mark.asyncio(loop_scope="module")
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
tool_get_current_weather = {
"type": "function",
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
response = await client.responses.create(
model=model,
input=messages,
tools=[tool_get_current_weather],
)
messages.extend(response.output)
function_call = check_tool_calling(response, True, "test_tool_calls: ")
assert function_call.name == "get_current_weather"
args = json.loads(function_call.arguments)
answer = get_current_weather(**args)
messages.append({
"type": "function_call_output",
"call_id": function_call.call_id,
"output": json.dumps(answer),
})
response = await client.responses.create(model=model,
input=messages,
tools=[tool_get_current_weather])
check_tool_calling(response, False, "test_tool_calls: ")
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming(client: openai.AsyncOpenAI, model: str):
stream = await client.responses.create(
model=model,
input="Explain the theory of relativity in brief.",
stream=True,
)
reasoning_deltas, message_deltas = list(), list()
async for event in stream:
if isinstance(event, ResponseTextDeltaEvent):
message_deltas.append(event.delta)
elif isinstance(event, ResponseReasoningTextDeltaEvent):
reasoning_deltas.append(event.delta)
full_response = "".join(message_deltas)
full_reasoning_response = "".join(reasoning_deltas)
assert full_response
assert full_reasoning_response
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
tool_get_current_weather = {
"type": "function",
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
stream = await client.responses.create(
model=model,
input=messages,
tools=[tool_get_current_weather],
stream=True,
)
function_call = None
reasoning_deltas = list()
async for event in stream:
if isinstance(event, ResponseCompletedEvent):
for output in event.response.output:
if output.type == "function_call":
function_call = output
elif isinstance(event, ResponseReasoningTextDeltaEvent):
reasoning_deltas.append(event.delta)
reasoning = "".join(reasoning_deltas)
    assert function_call is not None, "no function call in streamed response"
    tool_args = json.loads(function_call.arguments)
    assert function_call.name == "get_current_weather", "wrong function call name"
    assert tool_args, "tool args do not exist!"
    assert reasoning, "reasoning does not exist!"
get_current_weather(**tool_args)