[TRTLLM-7207][feat] Chat completions API for gpt-oss (#7261)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
This commit is contained in:
  parent f30768e70d
  commit c1e7fb9042
@@ -35,13 +35,7 @@ OpenAI MoE models support function calling. Here is an example based on [XGramma

First, launch a server with XGrammar enabled:

```bash
cat > ./extra_llm_api_options.yaml <<EOF
guided_decoding_backend: xgrammar
EOF

trtllm-serve <model> \
  --backend pytorch \
  --extra_llm_api_options extra_llm_api_options.yaml
trtllm-serve <model>
```

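Once the server is up, any OpenAI-compatible client can reach it. A minimal sketch of a tool-calling request (the `http://localhost:8000/v1` base URL, dummy API key, and `<model>` placeholder are assumptions about a local deployment, not part of this PR):

```python
# Minimal sketch: query the trtllm-serve endpoint with a tool definition.
# The base_url/port and the placeholder model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

weather_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

messages = [{"role": "user", "content": "What is the weather like in SF?"}]
response = client.chat.completions.create(model="<model>",
                                           messages=messages,
                                           tools=[weather_tool])
print(response.choices[0].message.tool_calls)
```
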
Run the [openai_chat_client_function_calling.py](./openai_chat_client_function_calling.py) script, which queries the LLM server in two steps:
@@ -68,14 +62,9 @@ The output would look similar to:

```txt
[USER PROMPT] What is the weather like in SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in SF?" They want the weather in SF. SF likely refers to San Francisco. We need to get the current weather. We can use get_current_weather function. We need to provide location string "San Francisco, CA". We can also ask for format? By default celsius. But maybe user expects Fahrenheit? They didn't specify. We can provide celsius or Fahrenheit. We can choose default celsius. But maybe better to provide Fahrenheit because US. But default is celsius. We can provide both? We can call function with format "fahrenheit" to be user-friendly. But the function default is celsius. We can override. Let's call get_current_weather with location "San Francisco, CA" and format "fahrenheit". Then we will get the weather. Then we will respond with friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>{
  "location": "San Francisco, CA",
  "format": "fahrenheit"
}<|call|>
[FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA', 'format': 'fahrenheit'})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in SF?" We have fetched the weather: sunny true, temperature 68 (F). We need to respond in a friendly tone. Provide a friendly answer: "It's sunny and 68°F in San Francisco." Possibly add a friendly comment. Also ask if they want more details.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! It’s a pleasant 68 °F in San Francisco right now, and the sun is shining. 🌞

Anything else you'd like to know about the weather or maybe some fun things to do in the city today?<|return|>
[RESPONSE 1] [COT] Need to call get_current_weather.
[RESPONSE 1] [FUNCTION CALL] get_current_weather(**{'location': 'San Francisco, CA'})
[RESPONSE 2] It’s a bright, sunny day in San Francisco with the temperature around 20 °C (68 °F). Enjoy the pleasant weather!
```

The function call works successfully:

@@ -95,14 +84,14 @@ The output would look like:

```txt
[USER PROMPT] What is the weather like in NY and SF?
[RESPONSE 1] <|channel|>analysis<|message|>The user asks: "What is the weather like in NY and SF?" They want the weather in New York and San Francisco. We need to provide the weather. We can use the function get_multiple_weathers. We need to provide the list of city and state strings. For New York, we can use "New York, NY". For San Francisco, "San Francisco, CA". We can call get_multiple_weathers with those two locations. We should specify format? The default is celsius. But maybe the user might want Fahrenheit? They didn't specify. We can just use default celsius. But maybe we can provide both? But the function only returns one format. We can just use default celsius. But we can also ask the user? But the user asked "What is the weather like in NY and SF?" We can just provide the weather. We can call the function. Then we will get the weather data. Then we can respond with a friendly tone. We need to call the function.<|end|><|start|>assistant<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>{"locations":["New York, NY","San Francisco, CA"]}<|call|>
[FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA']})
[RESPONSE 2] <|channel|>analysis<|message|>The user asked: "What is the weather like in NY and SF?" We called get_multiple_weathers with locations ["New York, NY","San Francisco, CA"]. The function returned: [{"sunny": true, "temperature": 20}, {"sunny": true, "temperature": 20}]. That seems to be a list of two objects, each with sunny: true, temperature: 20. But we need to interpret the function output. The function get_multiple_weathers presumably returns a list of weather data for each location. But the returned data is ambiguous: we don't know which corresponds to which location. But we can assume the order matches the input order: first is New York, second is San Francisco. The temperature is 20 degrees Celsius? The function didn't specify units, but default is celsius. So 20°C. And sunny: true. So both are sunny and 20°C. We should respond in a friendly tone, summarizing the weather for both cities. We can mention that it's sunny and 20°C in both New York and San Francisco. We can also mention that it's a nice day. We can ask if they want more details. We should not mention the function call. We should just provide the answer.<|end|><|start|>assistant<|channel|>final<|message|>Sure thing! Here’s the scoop:
[RESPONSE 1] [COT] Need to call get_multiple_weathers.
[RESPONSE 1] [FUNCTION CALL] get_multiple_weathers(**{'locations': ['New York, NY', 'San Francisco, CA'], 'format': 'celsius'})
[RESPONSE 2] Here’s a quick snapshot of the current weather in both cities:

- **New York, NY**: It’s sunny and a comfortable 20 °C (68 °F).
- **San Francisco, CA**: Also sunny with a pleasant 20 °C (68 °F).

Looks like both coasts are enjoying a bright, mild day. Let me know if you’d like a forecast for later or any other details!<|return|>
| City | Weather | Temperature |
|------|---------|-------------|
| New York | ☀️ Sunny | 20 °C |
| San Francisco | ☀️ Sunny | 20 °C |
```

Once again, the function call works successfully, this time using a different function: `get_multiple_weathers`.
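The second request in the script simply appends the assistant's tool call and the locally computed result, then asks for the final answer. Continuing the client sketch from the server-launch section above (the weather payload is fabricated for illustration):

```python
# Second step (sketch), continuing from the snippet above: run the tool locally
# and hand its result back as a "tool" message. The weather payload is made up.
import json

message = response.choices[0].message
tool_call = message.tool_calls[0]
messages.extend([
    {"role": "assistant", "tool_calls": [tool_call], "reasoning": message.reasoning},
    {"role": "tool", "content": json.dumps({"sunny": True, "temperature": 20}),
     "tool_call_id": tool_call.id},
])
final = client.chat.completions.create(model="<model>", messages=messages)
print(final.choices[0].message.content)
```
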
@@ -1,82 +1,58 @@
import argparse
import json
import re

from openai import OpenAI

system_prompt = """You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-06-28

Reasoning: high

# Valid channels: analysis, commentary, final. Channel must be included for every message.
Calls to these tools must go to the commentary channel: 'functions'."""

developer_prompt = """# Instructions

Use a friendly tone.

# Tools

## functions

namespace functions {

// Gets the location of the user.
type get_location = () => any;

// Gets the current weather in the provided location.
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;

// Gets the current weather in the provided list of locations.
type get_multiple_weathers = (_: {
// List of city and state, e.g. ["San Francisco, CA", "New York, NY"]
locations: string[],
format?: "celsius" | "fahrenheit", // default: celsius
}) => any;

} // namespace functions"""

schema_get_current_weather = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "description": "default: celsius",
            "enum": ["celsius", "fahrenheit"],
        },
    },
    "required": ["location"],
tool_get_current_weather = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Gets the current weather in the provided location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "description": "default: celsius",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        }
    }
}

schema_get_multiple_weathers = {
    "type": "object",
    "properties": {
        "locations": {
            "type":
            "array",
            "items": {
                "type": "string"
tool_get_multiple_weathers = {
    "type": "function",
    "function": {
        "name": "get_multiple_weathers",
        "description":
        "Gets the current weather in the provided list of locations.",
        "parameters": {
            "type": "object",
            "properties": {
                "locations": {
                    "type":
                    "array",
                    "items": {
                        "type": "string"
                    },
                    "description":
                    'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
                },
                "format": {
                    "type": "string",
                    "description": "default: celsius",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "description":
            'List of city and state, e.g. ["San Francisco, CA", "New York, NY"]',
        },
        "format": {
            "type": "string",
            "description": "default: celsius",
            "enum": ["celsius", "fahrenheit"],
        },
    },
    "required": ["locations"],
    "required": ["locations"],
        }
    }
}


@@ -103,14 +79,6 @@ def main():
    )

    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "developer",
            "content": developer_prompt,
        },
        {
            "role": "user",
            "content": args.prompt,
@@ -122,65 +90,41 @@ def main():
        model=args.model,
        messages=messages,
        max_completion_tokens=500,
        response_format={
            "type":
            "structural_tag",
            "structures": [{
                "begin":
                "<|channel|>commentary to=get_current_weather <|constrain|>json<|message|>",
                "schema": schema_get_current_weather,
                "end": "<|call|>",
            }, {
                "begin":
                "<|channel|>commentary to=get_multiple_weathers <|constrain|>json<|message|>",
                "schema": schema_get_multiple_weathers,
                "end": "<|call|>",
            }],
            "triggers": ["<|channel|>commentary to="],
        },
        stop=["<|call|>"],
        extra_body={
            "skip_special_tokens": False,
            "include_stop_str_in_output": True,
        },
        tools=[tool_get_current_weather, tool_get_multiple_weathers],
    )
    tools = {
        "get_current_weather": get_current_weather,
        "get_multiple_weathers": get_multiple_weathers
    }
    message = chat_completion.choices[0].message
    assert message, "Empty Message"
    assert message.tool_calls, "Empty tool calls"
    assert message.content is None, "Empty content expected"
    reasoning = message.reasoning if hasattr(message, "reasoning") else None
    tool_call = message.tool_calls[0]
    func_name = tool_call.function.name
    assert func_name in tools, "Invalid function name"
    kwargs = json.loads(tool_call.function.arguments)

    response_text = chat_completion.choices[0].message.content
    print(f"[RESPONSE 1] {response_text}")

    for regex, tool in [
        (r"(<\|channel\|>commentary to=get_current_weather <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
         get_current_weather),
        (r"(<\|channel\|>commentary to=get_multiple_weathers <\|constrain\|>json<\|message\|>)([\S\s]+)(<\|call\|>)",
         get_multiple_weathers)
    ]:
        match = re.search(regex, response_text)
        if match is not None:
            break
    else:
        print("Failed to call functions, exiting...")
        return

    kwargs = json.loads(match.group(2))
    print(f"[FUNCTION CALL] {tool.__name__}(**{kwargs})")
    tool = tools[func_name]
    print(f"[RESPONSE 1] [COT] {reasoning}")
    print(f"[RESPONSE 1] [FUNCTION CALL] {tool.__name__}(**{kwargs})")
    answer = tool(**kwargs)

    messages.extend([{
        "role": "assistant",
        "content": match.group(0),
        "reasoning": reasoning,
        "tool_calls": [tool_call],
    }, {
        "role": f"{tool.__name__} to=assistant",
        "role": "tool",
        "content": json.dumps(answer),
        "tool_call_id": tool_call.id
    }])

    chat_completion = client.chat.completions.create(
        model=args.model,
        messages=messages,
        max_completion_tokens=500,
        extra_body={
            "skip_special_tokens": False,
            "include_stop_str_in_output": True,
        },
    )

    response_text = chat_completion.choices[0].message.content
@@ -67,3 +67,4 @@ soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
@@ -466,25 +466,41 @@ class GenerationExecutorWorker(GenerationExecutor):

        def _deduce_max_tokens(request: GenerationRequest,
                               executor_config: tllm.ExecutorConfig) -> int:
            if request.sampling_params.max_tokens:
                return request.sampling_params.max_tokens
            # deduce max_tokens when it's not set by user
            max_tokens = request.sampling_params.max_tokens
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0
            cp_size = 1 if (not hasattr(executor_config, "mapping")
                            or executor_config.mapping.cp_size
                            is None) else executor_config.mapping.cp_size
            if not hasattr(executor_config, "max_seq_len"):
                raise RuntimeError(
                    "max_tokens for sampling is not set and cannot be deduced")
                logger.warning("`default_max_tokens` cannot be deduced")
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                    )
            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
            if default_max_tokens < 0:
                raise ValueError(
                    f"Deduced max_tokens {default_max_tokens} is less than 0, because"
                    f"prompt length {splited_prompt_len} plus query length {query_token_len} "
                    f"is larger than max_seq_len {executor_config.max_seq_len}")
            return default_max_tokens
            if default_max_tokens <= 0:
                logger.warning(
                    f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
                    f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                )
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` is illegal"
                    )
            # default_max_tokens is the biggest available value
            if max_tokens is None:
                return default_max_tokens
            elif max_tokens > default_max_tokens:
                logger.warning(
                    f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
                    f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
                )
                return default_max_tokens
            return max_tokens

        try:
            executor_request = tllm.Request(
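For reference, a self-contained sketch of the new deduction and clamping arithmetic (all numbers are made up; the real values come from `executor_config` and the request):

```python
# Illustrative arithmetic only; mirrors the deduction in _deduce_max_tokens.
max_seq_len = 8192      # executor_config.max_seq_len
prompt_len = 6000       # len(prompt_token_ids)
query_token_len = 0     # len(request.query_token_ids) if present, else 0
cp_size = 2             # executor_config.mapping.cp_size (1 when unset)

splited_prompt_len = int(prompt_len / cp_size)                           # 3000
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len  # 5192

# A user-specified max_tokens larger than the deduced budget is warned about
# and clamped to default_max_tokens by the new code path.
requested_max_tokens = 8000
effective = min(requested_max_tokens, default_max_tokens)                # 5192
print(effective)
```
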
tensorrt_llm/serve/harmony_adapter.py (new file, 1598 lines)
File diff suppressed because it is too large.
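The adapter diff is collapsed here, but its job is essentially a round trip through the new `openai-harmony` dependency: render OpenAI-style chat messages into Harmony tokens for gpt-oss, and parse generated tokens back into channels and tool calls. A rough sketch using the upstream `openai_harmony` API (names follow the harmony package docs; the adapter's actual internals may differ):

```python
# Rough sketch of the Harmony round trip; not the adapter's actual code.
from openai_harmony import (Conversation, HarmonyEncodingName, Message, Role,
                            load_harmony_encoding)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

# 1. OpenAI-style messages -> Harmony prompt tokens.
convo = Conversation.from_messages(
    [Message.from_role_and_content(Role.USER, "What is the weather like in SF?")])
prompt_tokens = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

# 2. Generation stops on Harmony's assistant-action stop tokens.
stop_token_ids = encoding.stop_tokens_for_assistant_actions()

# 3. Generated token ids -> structured messages (analysis/commentary/final
#    channels, tool calls), which are then mapped back to Chat Completions fields.
# parsed = encoding.parse_messages_from_completion_tokens(completion_token_ids,
#                                                         Role.ASSISTANT)
```
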
@@ -6,10 +6,12 @@ import uuid
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionAssistantMessageParam
from openai.types.chat import \
    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import \
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict

@@ -327,6 +329,11 @@ class FunctionCall(OpenAIBaseModel):
    arguments: str


class DeltaFunctionCall(OpenAIBaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCall(OpenAIBaseModel):
    id: str = Field(
        default_factory=lambda: f"chatcmpl-tool-{str(uuid.uuid4().hex)}")
@@ -334,10 +341,18 @@ class ToolCall(OpenAIBaseModel):
    function: FunctionCall


class DeltaToolCall(OpenAIBaseModel):
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCall] = None


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    reasoning: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)

@@ -378,8 +393,14 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
    """


class ReasoningAssistantMessage(ChatCompletionAssistantMessageParam):
    """Assistant message that includes reasoning tokens."""
    reasoning: Optional[str]


ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
                                   CustomChatCompletionMessageParam,
                                   ReasoningAssistantMessage]


class ChatCompletionLogProbs(OpenAIBaseModel):
@@ -416,7 +437,9 @@ class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)
    # For GPT-OSS style reasoning
    reasoning: Optional[str] = None
    tool_calls: Optional[List[DeltaToolCall]] = None


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
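With `DeltaMessage` now carrying `DeltaToolCall` fragments instead of complete `ToolCall`s, a streaming client reassembles each call by concatenating the `arguments` pieces per `index`. A client-side sketch (the `chunks` iterable is a stand-in for an already-received streamed chat completion):

```python
# Client-side sketch: fold DeltaToolCall fragments back into complete tool calls.
# `chunks` is a placeholder for the chunks of a streamed chat completion.
import json

calls: dict[int, dict] = {}  # index -> {"name": ..., "arguments": str}
for chunk in chunks:
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        entry = calls.setdefault(tc.index, {"name": None, "arguments": ""})
        if tc.function and tc.function.name:
            entry["name"] = tc.function.name
        if tc.function and tc.function.arguments:
            entry["arguments"] += tc.function.arguments

for idx, call in calls.items():
    print(idx, call["name"], json.loads(call["arguments"]))
```
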
@@ -469,8 +492,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    top_logprobs: Optional[int] = 0
    max_completion_tokens: int = Field(default=None,
                                       validation_alias='max_tokens')
    max_completion_tokens: Optional[int] = Field(default=None,
                                                 validation_alias='max_tokens')
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
@@ -481,9 +504,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
    tool_choice: Optional[Union[Literal["none", "auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None
    reasoning_effort: Optional[ReasoningEffort | Literal[
        "low", "medium", "high"]] = Field(
            default=ReasoningEffort.LOW,
            description=(
                "The level of reasoning effort to use. Controls how much "
                "reasoning is shown in the model's response. Options: "
                "'low', 'medium', 'high'."),
        )

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
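The new `reasoning_effort` field can be exercised straight from the standard OpenAI client (a sketch; the endpoint and `<model>` name are placeholders for a local deployment):

```python
# Sketch: exercising the new reasoning_effort field ("low" / "medium" / "high").
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
response = client.chat.completions.create(
    model="<model>",
    messages=[{"role": "user", "content": "Which one is larger, 9.9 or 9.11?"}],
    reasoning_effort="high",
)
print(response.choices[0].message.reasoning)  # GPT-OSS reasoning text
print(response.choices[0].message.content)    # final answer
```
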
@@ -610,9 +641,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
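In practice this validator means a request that supplies `tools` but omits `tool_choice` is treated as `tool_choice: "auto"`, while any explicit non-`none` choice must name a specific tool. An illustrative payload (values are placeholders):

```python
# Illustrative request body: `tools` present, `tool_choice` omitted,
# so check_tool_choice fills in "auto" before validation.
payload = {
    "model": "<model>",
    "messages": [{"role": "user", "content": "What is the weather like in SF?"}],
    "tools": [{
        "type": "function",
        "function": {"name": "get_current_weather",
                     "parameters": {"type": "object", "properties": {}}},
    }],
    # no "tool_choice" here -> defaults to "auto"
}
```
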
@@ -49,6 +49,9 @@ from tensorrt_llm.serve.postprocess_handlers import (
from tensorrt_llm.version import __version__ as VERSION

from .._utils import nvtx_mark, set_prometheus_multiproc_dir
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
                              handle_streaming_response,
                              maybe_transform_reasoning_effort)

# yapf: enable
TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -98,6 +101,10 @@ class OpenAIServer:
        self.perf_metrics = deque(maxlen=max_perf_metrics)
        self.perf_metrics_lock = asyncio.Lock()

        # gpt-oss
        self.harmony_adapter: HarmonyAdapter | None = None
        self.use_harmony = self.model_config.model_type == "gpt_oss"

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            if self.metadata_server is not None:
@@ -158,7 +165,6 @@ class OpenAIServer:
            return JSONResponse(content=error_response.model_dump(),
                                status_code=error_response.code)


    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
@@ -173,7 +179,7 @@ class OpenAIServer:
                               self.openai_completion,
                               methods=["POST"])
        self.app.add_api_route("/v1/chat/completions",
                               self.openai_chat,
                               self.openai_chat if not self.use_harmony else self.chat_harmony,
                               methods=["POST"])
        if self.llm.args.return_perf_metrics:
            # register /prometheus/metrics
@@ -661,6 +667,78 @@ class OpenAIServer:
            logger.error(traceback.format_exc())
            return self.create_error_response(str(e))

    async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Request) -> Response:
        """
        Chat Completion API with harmony format support.
        Supports both streaming and non-streaming modes.
        """
        try:
            # Initialize HarmonyAdapter
            # NOTE: WAR for Disagg failure, may affect perf if no warmup
            if not self.harmony_adapter:
                self.harmony_adapter = HarmonyAdapter()
            # Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
            tools_dict = None
            if request.tools:
                tools_dict = [tool.model_dump() for tool in request.tools]

            # Reasoning effort precedence: request.reasoning_effort > system message parsing > serving default
            reasoning_effort = maybe_transform_reasoning_effort(request.reasoning_effort)
            # Get tool_choice from request
            tool_choice = getattr(request, 'tool_choice', None)

            try:
                harmony_tokens = self.harmony_adapter.openai_to_harmony_tokens(
                    request.messages,
                    tools_dict,
                    reasoning_effort=reasoning_effort,
                    tool_choice=tool_choice
                )
            except Exception as e:
                logger.error(f"messages_dict: {request.messages}")
                logger.error(f"tools_dict: {tools_dict}")
                logger.error(f"request: {request}")
                raise e

            # Get harmony stop tokens
            harmony_stop_tokens = self.harmony_adapter.get_stop_tokens()
            if request.stop_token_ids:
                request.stop_token_ids.extend(harmony_stop_tokens)
            else:
                request.stop_token_ids = harmony_stop_tokens

            sampling_params = request.to_sampling_params(
                vocab_size=self.tokenizer.tokenizer.vocab_size)
            sampling_params.detokenize = False  # Harmony adapter handles detokenization

            # Generate
            promise = self.llm.generate_async(
                inputs=harmony_tokens,
                sampling_params=sampling_params,
                streaming=bool(request.stream),
                lora_request=request.lora_request,
            )
            # Disconnect cancellation
            asyncio.create_task(self.await_disconnected(raw_request, promise))

            # Handle streaming
            if request.stream:
                return StreamingResponse(
                    handle_streaming_response(
                        self.harmony_adapter, promise,
                        str(promise.request_id), request,
                    ),
                    media_type="text/event-stream"
                )
            else:
                response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
                return JSONResponse(response.model_dump())

        except Exception as e:
            logger.error("Error in harmony chat completion: %s", e)
            logger.debug("Error details: %s", traceback.format_exc())
            return self.create_error_response(message=str(e), err_type="internal_error")

    async def __call__(self, host, port):
        # Store the binding address for server registration
        self.binding_addr = f"http://{host}:{port}"
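For a quick smoke test of the new route, the endpoint can also be driven with a raw HTTP request. A hedged sketch using `requests` (host, port, model name, and the SSE `data:` framing are assumptions about a standard OpenAI-compatible local deployment):

```python
# Raw-HTTP sketch of a streaming chat completion against /v1/chat/completions.
# Host/port and model name are assumptions about the local deployment.
import json

import requests

payload = {
    "model": "<model>",
    "messages": [{"role": "user", "content": "Tell me a joke."}],
    "reasoning_effort": "low",
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line and line.startswith(b"data: ") and line != b"data: [DONE]":
            chunk = json.loads(line[len(b"data: "):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("reasoning") or delta.get("content") or "", end="")
```
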
@@ -1506,6 +1506,13 @@ def test_openai_perf_metrics(llm_root, llm_venv):
        str(test_root / "_test_openai_perf_metrics.py")])


def test_openai_chat_harmony(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(
        ["-m", "pytest",
         str(test_root / "_test_openai_chat_harmony.py")])


def test_openai_prometheus(llm_root, llm_venv):
    test_root = unittest_path() / "llmapi" / "apps"
    llm_venv.run_cmd(
@@ -102,6 +102,7 @@ l0_h100:
  - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
  - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
  - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
  - test_e2e.py::test_openai_chat_harmony
  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
  # ------------- AutoDeploy tests ---------------
  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
tests/unittest/llmapi/apps/_test_openai_chat_harmony.py (new file, 217 lines)
@@ -0,0 +1,217 @@
import json

import openai
import pytest

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer

pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
def model():
    return "gpt_oss/gpt-oss-20b/"


@pytest.fixture(scope="module")
def server(model: str):
    model_path = get_model_path(model)
    with RemoteOpenAIServer(model_path) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer):
    return server.get_async_client()


@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning(client: openai.AsyncOpenAI, model: str):
    response = await client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": "Which one is larger as numeric, 9.9 or 9.11?"
        }])
    assert response.choices[0].message.content
    assert response.choices[0].message.reasoning


@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
    response = await client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": "Which one is larger as numeric, 9.9 or 9.11?"
        }],
        reasoning_effort="Medium")
    assert response.choices[0].message.content
    assert response.choices[0].message.reasoning


@pytest.mark.asyncio(loop_scope="module")
async def test_chat(client: openai.AsyncOpenAI, model: str):
    response = await client.chat.completions.create(
        model=model,
        messages=[{
            "role": "developer",
            "content": "Respond in Chinese."
        }, {
            "role": "user",
            "content": "Hello!"
        }, {
            "role": "assistant",
            "content": "Hello! How can I help you?"
        }, {
            "role": "user",
            "content": "Tell me a joke."
        }])
    assert response.choices[0].message.content
    assert response.choices[0].message.reasoning


def get_current_weather(location: str, format: str = "celsius") -> dict:
    return {"sunny": True, "temperature": 20 if format == "celsius" else 68}


@pytest.mark.asyncio(loop_scope="module")
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
    tool_get_current_weather = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Gets the current weather in the provided location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description":
                        "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "description": "default: celsius",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            }
        }
    }
    messages = [{"role": "user", "content": "What is the weather like in SF?"}]
    response = await client.chat.completions.create(
        model=model,
        messages=messages,
        tools=[tool_get_current_weather],
    )
    message = response.choices[0].message
    assert response.choices[0].finish_reason == "tool_calls"
    assert message.content is None
    assert message.reasoning
    assert message.tool_calls
    assert len(message.tool_calls) == 1
    tool_call = message.tool_calls[0]
    assert tool_call.function.name == "get_current_weather"
    args = json.loads(tool_call.function.arguments)
    answer = get_current_weather(**args)
    messages.extend([{
        "role": "assistant",
        "tool_calls": [tool_call],
        "reasoning": message.reasoning
    }, {
        "role": "tool",
        "content": json.dumps(answer),
        "tool_call_id": tool_call.id
    }])
    response = await client.chat.completions.create(
        model=model,
        messages=messages,
    )
    message = response.choices[0].message
    assert message.content


@pytest.mark.asyncio(loop_scope="module")
async def test_streaming(client: openai.AsyncOpenAI, model: str):
    response = await client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": "Explain the theory of relativity in brief."
        }],
        stream=True,
    )
    collected_chunks = []
    collected_messages = []
    async for chunk in response:
        collected_chunks.append(chunk)
        collected_messages.append(chunk.choices[0].delta)

    full_response = "".join([
        m.content for m in collected_messages
        if hasattr(m, "content") and m.content
    ])
    full_reasoning_response = "".join([
        m.reasoning for m in collected_messages
        if hasattr(m, "reasoning") and m.reasoning
    ])
    assert full_response
    assert full_reasoning_response


@pytest.mark.asyncio(loop_scope="module")
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
    tool_get_current_weather = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Gets the current weather in the provided location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description":
                        "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "description": "default: celsius",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            }
        }
    }
    messages = [{"role": "user", "content": "What is the weather like in SF?"}]
    response = await client.chat.completions.create(
        model=model,
        messages=messages,
        tools=[tool_get_current_weather],
        stream=True,
    )
    tool_name: str
    reasoning_chunks: list[str] = []
    tool_arg_chunks: list[str] = []
    async for chunk in response:
        delta = chunk.choices[0].delta
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            function = delta.tool_calls[0].function
            if hasattr(function, "name") and function.name:
                tool_name = function.name
            if hasattr(function, "arguments") and function.arguments:
                args_str = function.arguments
                tool_arg_chunks.append(args_str)
        if hasattr(delta, "reasoning") and delta.reasoning:
            reasoning_chunks.append(delta.reasoning)
    reasoning = "".join(reasoning_chunks)
    tool_args = "".join(tool_arg_chunks)
    assert tool_name == "get_current_weather"
    assert tool_args
    assert reasoning
    args = json.loads(tool_args)
    get_current_weather(**args)