[TRTLLM-7207][feat] Chat completions API for gpt-oss (#7261)

Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
Pengyun Lin 2025-08-28 10:22:06 +08:00 committed by GitHub
parent f30768e70d
commit c1e7fb9042
10 changed files with 2048 additions and 166 deletions


@@ -35,13 +35,7 @@ OpenAI MoE models support function calling. Here is an example based on [XGramma
First, launch a server with XGrammar enabled:
```bash
cat > ./extra_llm_api_options.yaml <<EOF
guided_decoding_backend: xgrammar
EOF
trtllm-serve <model> \
--backend pytorch \
--extra_llm_api_options extra_llm_api_options.yaml
trtllm-serve <model>
```
Run the [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) script, which queries the LLM server in two steps:
@@ -68,14 +62,9 @@ The output would look similar to:
```txt
[USER PROMPT] What is the weather like in SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" They want the weather in SF. SF likely refers to San Francisco. We need to get the current weather. We can use get_current_weather function. We need to provide location string "San Francisco, CA". We can also ask for format? By default celsius. But maybe user expects Fahrenheit? They didn't specify. We can provide celsius or Fahrenheit. We can choose default celsius. But maybe better to provide Fahrenheit because US. But default is celsius. We can provide both? We can call function with format "fahrenheit" to be user-friendly. But the function default is celsius. We can override. Let's call get_current_weather with location "San Francisco, CA" and format "fahrenheit". Then we will get the weather. Then we will respond with friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>{
"location": "San Francisco, CA",
"format": "fahrenheit"
}<|call|>
[FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA', 'format': 'fahrenheit'})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in SF?" We have fetched the weather: sunny true, temperature 68 (F). We need to respond in a friendly tone. Provide a friendly answer: "It's sunny and 68°F in San Francisco." Possibly add a friendly comment. Also ask if they want more details.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! It's a pleasant 68°F in San Francisco right now, and the sun is shining. 🌞
Anything else you'd like to know about the weather or maybe some fun things to do in the city today?<|return|>
[RESPONSE 1] [COT] Need to call get_current_weather.
[RESPONSE 1] [FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA'})
[RESPONSE 2] It's a bright, sunny day in San Francisco with the temperature around 20°C (68°F). Enjoy the pleasant weather!
```
The function call works successfully:
@@ -95,14 +84,14 @@ The output would look like:
```txt
[USER PROMPT] What is the weather like in NY and SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in NY and SF?" They want the weather in New York and San Francisco. We need to provide the weather. We can use the function get_multiple_weathers. We need to provide the list of city and state strings. For New York, we can use "New York, NY". For San Francisco, "San Francisco, CA". We can call get_multiple_weathers with those two locations. We should specify format? The default is celsius. But maybe the user might want Fahrenheit? They didn't specify. We can just use default celsius. But maybe we can provide both? But the function only returns one format. We can just use default celsius. But we can also ask the user? But the user asked "What is the weather like in NY and SF?" We can just provide the weather. We can call the function. Then we will get the weather data. Then we can respond with a friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>{"locations":["New York, NY","San Francisco, CA"]}<|call|>
[FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA']})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in NY and SF?" We called get_multiple_weathers with locations ["New York, NY","San Francisco, CA"]. The function returned: [{"sunny": true, "temperature": 20}, {"sunny": true, "temperature": 20}]. That seems to be a list of two objects, each with sunny: true, temperature: 20. But we need to interpret the function output. The function get_multiple_weathers presumably returns a list of weather data for each location. But the returned data is ambiguous: we don't know which corresponds to which location. But we can assume the order matches the input order: first is New York, second is San Francisco. The temperature is 20 degrees Celsius? The function didn't specify units, but default is celsius. So 20°C. And sunny: true. So both are sunny and 20°C. We should respond in a friendly tone, summarizing the weather for both cities. We can mention that it's sunny and 20°C in both New York and San Francisco. We can also mention that it's a nice day. We can ask if they want more details. We should not mention the function call. We should just provide the answer.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! Here's the scoop:
[RESPONSE 1] [COT] Need to call get_multiple_weathers.
[RESPONSE 1] [FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA'], 'format': 'celsius'})
[RESPONSE 2] Here's a quick snapshot of the current weather in both cities:
- **New York, NY**: It's sunny and a comfortable 20°C (68°F).
- **San Francisco, CA**: Also sunny with a pleasant 20°C (68°F).
Looks like both coasts are enjoying a bright, mild day. Let me know if you'd like a forecast for later or any other details!<|return|>
| City | Weather | Temperature |
|------|---------|-------------|
| New York | ☀️ Sunny | 20°C |
| San Francisco | ☀️ Sunny | 20°C |
```
Once again, the function call works successfully, this time using a different function: `get_multiple_weathers`.
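For reference, the two-step exchange the example script performs can be condensed roughly as follows. This is a minimal sketch: the base URL, API key, and model name are placeholders, and `get_current_weather` is a local stub rather than a real weather service.

```python
# Minimal sketch of the two-step tool-calling flow against trtllm-serve.
# Assumptions: server at http://localhost:8000/v1, placeholder model name.
import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}


def get_current_weather(location: str, format: str = "celsius") -> dict:
    # Local stub standing in for a real weather lookup.
    return {"sunny": True, "temperature": 20 if format == "celsius" else 68}


messages = [{"role": "user", "content": "What is the weather like in SF?"}]

# Step 1: the model decides to call the tool.
first = client.chat.completions.create(model="<model>",
                                        messages=messages,
                                        tools=[tool])
message = first.choices[0].message
tool_call = message.tool_calls[0]
args = json.loads(tool_call.function.arguments)
answer = get_current_weather(**args)

# Step 2: feed the tool result back and get the final answer.
messages.extend([
    {
        "role": "assistant",
        "tool_calls": [tool_call],
        # `reasoning` is the extension field returned by this server.
        "reasoning": getattr(message, "reasoning", None),
    },
    {
        "role": "tool",
        "content": json.dumps(answer),
        "tool_call_id": tool_call.id,
    },
])
second = client.chat.completions.create(model="<model>", messages=messages)
print(second.choices[0].message.content)
```

The second request carries the assistant turn (with its `reasoning` and `tool_calls`) plus a `tool` message keyed by `tool_call_id`, which is the shape the new Harmony-based chat endpoint expects.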


@@ -1,82 +1,58 @@
import argparse
import json
import re
from openai import OpenAI
system_prompt = """You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-06-28
Reasoning: high
# Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'."""
developer_prompt = """# Instructions
Use a friendly tone.
# Tools
## functions
namespace functions {
// Gets the location of the user.
type get_location = () => any;
// Gets the current weather in the provided location.
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;
// Gets the current weather in the provided list of locations.
type get_multiple_weathers = (_: {
// List of city and state, e.g. ["San Francisco, CA", "New York, NY"]
locations: string[],
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;
} // namespace functions"""
schema_get_current_weather = {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
}
schema_get_multiple_weathers = {
"type": "object",
"properties": {
"locations": {
"type":
"array",
"items": {
"type": "string"
tool_get_multiple_weathers = {
"type": "function",
"function": {
"name": "get_multiple_weathers",
"description":
"Gets the current weather in the provided list of locations.",
"parameters": {
"type": "object",
"properties": {
"locations": {
"type":
"array",
"items": {
"type": "string"
},
"description":
'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"description":
'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["locations"],
"required": ["locations"],
}
}
}
@@ -103,14 +79,6 @@ def main():
)
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "developer",
"content": developer_prompt,
},
{
"role": "user",
"content": args.prompt,
@@ -122,65 +90,41 @@ def main():
model=args.model,
messages=messages,
max_completion_tokens=500,
response_format={
"type":
"structural_tag",
"structures": [{
"begin":
"<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>",
"schema": schema_get_current_weather,
"end": "<|call|>",
}, {
"begin":
"<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>",
"schema": schema_get_multiple_weathers,
"end": "<|call|>",
}],
"triggers": ["<|channel|>commentary to="],
},
stop=["<|call|>"],
extra_body={
"skip_special_tokens": False,
"include_stop_str_in_output": True,
},
tools=[tool_get_current_weather, tool_get_multiple_weathers],
)
tools = {
"get_current_weather": get_current_weather,
"get_multiple_weathers": get_multiple_weathers
}
message = chat_completion.choices[0].message
assert message, "Empty Message"
assert message.tool_calls, "Empty tool calls"
assert message.content is None, "Empty content expected"
reasoning = message.reasoning if hasattr(message, "reasoning") else None
tool_call = message.tool_calls[0]
func_name = tool_call.function.name
assert func_name in tools, "Invalid function name"
kwargs = json.loads(tool_call.function.arguments)
response_text = chat_completion.choices[0].message.content
print(f"[RESPONSE 1] {response_text}")
for regex, tool in [
(r"(<\|channel\|>commentary to=get_current_weather <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
get_current_weather),
(r"(<\|channel\|>commentary to=get_multiple_weathers <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
get_multiple_weathers)
]:
match = re.search(regex, response_text)
if match is not None:
break
else:
print("Failed to call functions, exiting...")
return
kwargs = json.loads(match.group(2))
print(f"[FUNCTION CALL] {tool.__name__}(**{kwargs})")
tool = tools[func_name]
print(f"[RESPONSE 1] [COT] {reasoning}")
print(f"[RESPONSE 1] [FUNCTION CALL] {tool.__name__}(**{kwargs})")
answer = tool(**kwargs)
messages.extend([{
"role": "assistant",
"content": match.group(0),
"reasoning": reasoning,
"tool_calls": [tool_call],
}, {
"role": f"{tool.__name__} to=assistant",
"role": "tool",
"content": json.dumps(answer),
"tool_call_id": tool_call.id
}])
chat_completion = client.chat.completions.create(
model=args.model,
messages=messages,
max_completion_tokens=500,
extra_body={
"skip_special_tokens": False,
"include_stop_str_in_output": True,
},
)
response_text = chat_completion.choices[0].message.content


@ -67,3 +67,4 @@ soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4


@@ -466,25 +466,41 @@ class GenerationExecutorWorker(GenerationExecutor):
def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig) -> int:
if request.sampling_params.max_tokens:
return request.sampling_params.max_tokens
# deduce max_tokens when it's not set by user
max_tokens = request.sampling_params.max_tokens
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1 if (not hasattr(executor_config, "mapping")
or executor_config.mapping.cp_size
is None) else executor_config.mapping.cp_size
if not hasattr(executor_config, "max_seq_len"):
raise RuntimeError(
"max_tokens for sampling is not set and cannot be deduced")
logger.warning("`default_max_tokens` cannot be deduced")
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
if default_max_tokens < 0:
raise ValueError(
f"Deduced max_tokens {default_max_tokens} is less than 0, because"
f"prompt length {splited_prompt_len} plus query length {query_token_len} "
f"is larger than max_seq_len {executor_config.max_seq_len}")
return default_max_tokens
if default_max_tokens <= 0:
logger.warning(
f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
)
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` is illegal"
)
# default_max_tokens is the biggest available value
if max_tokens is None:
return default_max_tokens
elif max_tokens > default_max_tokens:
logger.warning(
f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
)
return default_max_tokens
return max_tokens
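For intuition, a small worked example of the deduction above, using made-up numbers (all values below are hypothetical):

```python
# Hypothetical values illustrating the max_tokens deduction above.
max_seq_len = 8192        # executor_config.max_seq_len
prompt_token_len = 1000   # tokens in the prompt
query_token_len = 0       # no query tokens
cp_size = 2               # context-parallel size

splited_prompt_len = prompt_token_len // cp_size                          # 500
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len   # 7692

# A user-specified max_tokens larger than the deduced budget is clamped to
# default_max_tokens (with a warning); if max_tokens is unset, the deduced
# value is used directly.
requested_max_tokens = 16000
effective_max_tokens = min(requested_max_tokens, default_max_tokens)      # 7692
```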
try:
executor_request = tllm.Request(

File diff suppressed because it is too large.


@@ -6,10 +6,12 @@ import uuid
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionAssistantMessageParam
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import \
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
@@ -327,6 +329,11 @@ class FunctionCall(OpenAIBaseModel):
arguments: str
class DeltaFunctionCall(OpenAIBaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(OpenAIBaseModel):
id: str = Field(
default_factory=lambda: f"chatcmpl-tool-{str(uuid.uuid4().hex)}")
@@ -334,10 +341,18 @@ class ToolCall(OpenAIBaseModel):
function: FunctionCall
class DeltaToolCall(OpenAIBaseModel):
id: Optional[str] = None
type: Optional[Literal["function"]] = None
index: int
function: Optional[DeltaFunctionCall] = None
class ChatMessage(OpenAIBaseModel):
role: str
content: str
content: Optional[str] = None
reasoning_content: Optional[str] = None
reasoning: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
@@ -378,8 +393,14 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""
class ReasoningAssistantMessage(ChatCompletionAssistantMessageParam):
"""Assistant message that includes reasoning tokens."""
reasoning: Optional[str]
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
CustomChatCompletionMessageParam,
ReasoningAssistantMessage]
class ChatCompletionLogProbs(OpenAIBaseModel):
@@ -416,7 +437,9 @@ class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
# For GPT-OSS style reasoning
reasoning: Optional[str] = None
tool_calls: Optional[List[DeltaToolCall]] = None
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
@@ -469,8 +492,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
top_logprobs: Optional[int] = 0
max_completion_tokens: int = Field(default=None,
validation_alias='max_tokens')
max_completion_tokens: Optional[int] = Field(default=None,
validation_alias='max_tokens')
n: int = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
@@ -481,9 +504,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
tool_choice: Optional[Union[Literal["none", "auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None
reasoning_effort: Optional[ReasoningEffort | Literal[
"low", "medium", "high"]] = Field(
default=ReasoningEffort.LOW,
description=(
"The level of reasoning effort to use. Controls how much "
"reasoning is shown in the model's response. Options: "
"'low', 'medium', 'high'."),
)
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
@@ -610,9 +641,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" not in data and data.get("tools"):
data["tool_choice"] = "auto"
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")

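For illustration, a client request exercising the new `reasoning_effort` field might look as follows. This is a sketch only: the base URL and model name are placeholders, and the `reasoning` attribute on the response message is the extension field added in this change.

```python
from openai import OpenAI

# Placeholder endpoint; assumes a trtllm-serve instance running a gpt-oss model.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="<model>",
    messages=[{"role": "user", "content": "Which is larger, 9.9 or 9.11?"}],
    # New field; the server defaults to "low" when it is omitted.
    reasoning_effort="medium",
)
print(response.choices[0].message.reasoning)
print(response.choices[0].message.content)
```

When `tools` are supplied and `tool_choice` is omitted, the updated validator now defaults `tool_choice` to `"auto"` instead of rejecting the request.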

@@ -49,6 +49,9 @@ from tensorrt_llm.serve.postprocess_handlers import (
from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
handle_streaming_response,
maybe_transform_reasoning_effort)
# yapf: enable
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -98,6 +101,10 @@ class OpenAIServer:
self.perf_metrics = deque(maxlen=max_perf_metrics)
self.perf_metrics_lock = asyncio.Lock()
# gpt-oss
self.harmony_adapter: HarmonyAdapter | None = None
self.use_harmony = self.model_config.model_type == "gpt_oss"
@asynccontextmanager
async def lifespan(app: FastAPI):
if self.metadata_server is not None:
@@ -158,7 +165,6 @@
return JSONResponse(content=error_response.model_dump(),
status_code=error_response.code)
def register_routes(self):
self.app.add_api_route("/health", self.health, methods=["GET"])
self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
@@ -173,7 +179,7 @@
self.openai_completion,
methods=["POST"])
self.app.add_api_route("/v1/chat/completions",
self.openai_chat,
self.openai_chat if not self.use_harmony else self.chat_harmony,
methods=["POST"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
@@ -661,6 +667,78 @@
logger.error(traceback.format_exc())
return self.create_error_response(str(e))
async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Request) -> Response:
"""
Chat Completion API with harmony format support.
Supports both streaming and non-streaming modes.
"""
try:
# Initialize HarmonyAdapter
# NOTE: WAR for Disagg failure, may affect perf if no warmup
if not self.harmony_adapter:
self.harmony_adapter = HarmonyAdapter()
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
# Reasoning effort precedence: request.reasoning_effort > system message parsing > serving default
reasoning_effort = maybe_transform_reasoning_effort(request.reasoning_effort)
# Get tool_choice from request
tool_choice = getattr(request, 'tool_choice', None)
try:
harmony_tokens = self.harmony_adapter.openai_to_harmony_tokens(
request.messages,
tools_dict,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice
)
except Exception as e:
logger.error(f"messages_dict: {request.messages}")
logger.error(f"tools_dict: {tools_dict}")
logger.error(f"request: {request}")
raise e
# Get harmony stop tokens
harmony_stop_tokens = self.harmony_adapter.get_stop_tokens()
if request.stop_token_ids:
request.stop_token_ids.extend(harmony_stop_tokens)
else:
request.stop_token_ids = harmony_stop_tokens
sampling_params = request.to_sampling_params(
vocab_size=self.tokenizer.tokenizer.vocab_size)
sampling_params.detokenize = False # Harmony adapter handles detokenization
# Generate
promise = self.llm.generate_async(
inputs=harmony_tokens,
sampling_params=sampling_params,
streaming=bool(request.stream),
lora_request=request.lora_request,
)
# Disconnect cancellation
asyncio.create_task(self.await_disconnected(raw_request, promise))
# Handle streaming
if request.stream:
return StreamingResponse(
handle_streaming_response(
self.harmony_adapter, promise,
str(promise.request_id), request,
),
media_type="text/event-stream"
)
else:
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
return JSONResponse(response.model_dump())
except Exception as e:
logger.error("Error in harmony chat completion: %s", e)
logger.debug("Error details: %s", traceback.format_exc())
return self.create_error_response(message=str(e), err_type="internal_error")
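A rough sketch of how a client might consume the streaming branch above, reassembling the `reasoning` and `content` deltas; the base URL and model name are placeholders, and the unit test added later in this commit exercises the same path.

```python
import asyncio

import openai


async def main() -> None:
    # Placeholder endpoint; assumes a trtllm-serve instance running a gpt-oss model.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    stream = await client.chat.completions.create(
        model="<model>",
        messages=[{"role": "user", "content": "Explain KV caching briefly."}],
        stream=True,
    )
    reasoning_parts: list[str] = []
    content_parts: list[str] = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        # `reasoning` is the extension field emitted by this endpoint.
        if getattr(delta, "reasoning", None):
            reasoning_parts.append(delta.reasoning)
        if delta.content:
            content_parts.append(delta.content)
    print("".join(reasoning_parts))
    print("".join(content_parts))


asyncio.run(main())
```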
async def __call__(self, host, port):
# Store the binding address for server registration
self.binding_addr = f"http://{host}:{port}"


@@ -1506,6 +1506,13 @@ def test_openai_perf_metrics(llm_root, llm_venv):
str(test_root / "_test_openai_perf_metrics.py")])
def test_openai_chat_harmony(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_chat_harmony.py")])
def test_openai_prometheus(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(


@@ -102,6 +102,7 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- test_e2e.py::test_openai_chat_harmony
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype


@@ -0,0 +1,217 @@
import json
import openai
import pytest
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
pytestmark = pytest.mark.threadleak(enabled=False)
@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
def model():
return "gpt_oss/gpt-oss-20b/"
@pytest.fixture(scope="module")
def server(model: str):
model_path = get_model_path(model)
with RemoteOpenAIServer(model_path) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
return server.get_async_client()
@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning(client: openai.AsyncOpenAI, model: str):
response = await client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": "Which one is larger as numeric, 9.9 or 9.11?"
}])
assert response.choices[0].message.content
assert response.choices[0].message.reasoning
@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
response = await client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": "Which one is larger as numeric, 9.9 or 9.11?"
}],
reasoning_effort="Medium")
assert response.choices[0].message.content
assert response.choices[0].message.reasoning
@pytest.mark.asyncio(loop_scope="module")
async def test_chat(client: openai.AsyncOpenAI, model: str):
response = await client.chat.completions.create(
model=model,
messages=[{
"role": "developer",
"content": "Respond in Chinese."
}, {
"role": "user",
"content": "Hello!"
}, {
"role": "assistant",
"content": "Hello! How can I help you?"
}, {
"role": "user",
"content": "Tell me a joke."
}])
assert response.choices[0].message.content
assert response.choices[0].message.reasoning
def get_current_weather(location: str, format: str = "celsius") -> dict:
return {"sunny": True, "temperature": 20 if format == "celsius" else 68}
@pytest.mark.asyncio(loop_scope="module")
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
}
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
response = await client.chat.completions.create(
model=model,
messages=messages,
tools=[tool_get_current_weather],
)
message = response.choices[0].message
assert response.choices[0].finish_reason == "tool_calls"
assert message.content is None
assert message.reasoning
assert message.tool_calls
assert len(message.tool_calls) == 1
tool_call = message.tool_calls[0]
assert tool_call.function.name == "get_current_weather"
args = json.loads(tool_call.function.arguments)
answer = get_current_weather(**args)
messages.extend([{
"role": "assistant",
"tool_calls": [tool_call],
"reasoning": message.reasoning
}, {
"role": "tool",
"content": json.dumps(answer),
"tool_call_id": tool_call.id
}])
response = await client.chat.completions.create(
model=model,
messages=messages,
)
message = response.choices[0].message
assert message.content
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming(client: openai.AsyncOpenAI, model: str):
response = await client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": "Explain the theory of relativity in brief."
}],
stream=True,
)
collected_chunks = []
collected_messages = []
async for chunk in response:
collected_chunks.append(chunk)
collected_messages.append(chunk.choices[0].delta)
full_response = "".join([
m.content for m in collected_messages
if hasattr(m, "content") and m.content
])
full_reasoning_response = "".join([
m.reasoning for m in collected_messages
if hasattr(m, "reasoning") and m.reasoning
])
assert full_response
assert full_reasoning_response
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Gets the current weather in the provided location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"description": "default: celsius",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
}
}
}
messages = [{"role": "user", "content": "What is the weather like in SF?"}]
response = await client.chat.completions.create(
model=model,
messages=messages,
tools=[tool_get_current_weather],
stream=True,
)
tool_name: str
reasoning_chunks: list[str] = []
tool_arg_chunks: list[str] = []
async for chunk in response:
delta = chunk.choices[0].delta
if hasattr(delta, "tool_calls") and delta.tool_calls:
function = delta.tool_calls[0].function
if hasattr(function, "name") and function.name:
tool_name = function.name
if hasattr(function, "arguments") and function.arguments:
args_str = function.arguments
tool_arg_chunks.append(args_str)
if hasattr(delta, "reasoning") and delta.reasoning:
reasoning_chunks.append(delta.reasoning)
reasoning = "".join(reasoning_chunks)
tool_args = "".join(tool_arg_chunks)
assert tool_name == "get_current_weather"
assert tool_args
assert reasoning
args = json.loads(tool_args)
get_current_weather(**args)