mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][CI] Stabilize Granite tool-use and test URL construction (#43017)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -1,4 +1,8 @@
|
||||
{%- if tools %}
|
||||
{%- if messages and messages[0]['role'] != 'system' %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user query, respond with <|tool_call|> followed by a JSON list of tools used.<|end_of_text|>
|
||||
' }}
|
||||
{%- endif %}
|
||||
{{- '<|start_of_role|>available_tools<|end_of_role|>
|
||||
' }}
|
||||
{%- for tool in tools %}
|
||||
|
||||
@@ -13,9 +13,19 @@ from .utils import (
|
||||
SEED,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
ensure_system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def apply_parallel_tool_system_prompt(
    messages,
    server_config: ServerConfig,
):
    """Return the message list to send for the parallel tool-call tests.

    Only the ``ibm-granite/granite-3.0-8b-instruct`` configuration gets its
    messages routed through ``ensure_system_prompt``; every other
    configuration receives ``messages`` back untouched.

    NOTE(review): ``ensure_system_prompt`` is imported from ``.utils`` and is
    not visible here -- presumably it adds the config's ``system_prompt`` to
    the conversation; confirm against that helper.
    """
    is_granite_3_0 = (
        server_config["model"] == "ibm-granite/granite-3.0-8b-instruct"
    )
    if not is_granite_3_0:
        return messages
    return ensure_system_prompt(messages, server_config)
|
||||
|
||||
|
||||
# test: getting the model to generate parallel tool calls (streaming/not)
|
||||
# when requested. NOTE that not all models may support this, so some exclusions
|
||||
# may be added in the future. e.g. llama 3.1 models are not designed to support
|
||||
@@ -33,8 +43,11 @@ async def test_parallel_tool_calls(
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
messages = apply_parallel_tool_system_prompt(
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
|
||||
)
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
@@ -73,7 +86,7 @@ async def test_parallel_tool_calls(
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
@@ -162,8 +175,11 @@ async def test_parallel_tool_calls_with_results(
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
messages = apply_parallel_tool_system_prompt(
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, server_config
|
||||
)
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
@@ -182,7 +198,7 @@ async def test_parallel_tool_calls_with_results(
|
||||
assert "78" in choice.message.content # Orlando temp in tool response
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
@@ -220,15 +236,20 @@ async def test_parallel_tool_calls_with_results(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
async def test_parallel_tool_calls_false(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
"""
|
||||
Ensure only one tool call is returned when parallel_tool_calls is False.
|
||||
"""
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
messages = apply_parallel_tool_system_prompt(
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
|
||||
)
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
@@ -248,7 +269,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
|
||||
@@ -12,17 +12,22 @@ from .utils import (
|
||||
SEARCH_TOOL,
|
||||
SEED,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
ensure_system_prompt,
|
||||
)
|
||||
|
||||
|
||||
# test: request a chat completion that should return tool calls, so we know they
|
||||
# are parsable
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
async def test_tool_call_and_choice(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
messages = ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config)
|
||||
chat_completion = await client.chat.completions.create(
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
@@ -68,7 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=MESSAGES_ASKING_FOR_TOOLS,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_completion_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
|
||||
@@ -201,6 +201,10 @@ CONFIGS: dict[str, ServerConfig] = {
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"),
|
||||
],
|
||||
"system_prompt": "You are a helpful AI assistant with access to tools. "
|
||||
"Use two-letter US state abbreviations in weather tool arguments. "
|
||||
"When a tool is required to answer the user query, respond with "
|
||||
"<|tool_call|> followed by a JSON list of tools used.",
|
||||
},
|
||||
"granite-3.1-8b": {
|
||||
"model": "ibm-granite/granite-3.1-8b-instruct",
|
||||
|
||||
+2
-1
@@ -660,7 +660,8 @@ class RemoteVLLMServer:
|
||||
)
|
||||
|
||||
def url_for(self, *parts: str) -> str:
    """Build a URL under ``self.url_root`` from the given path segments.

    Leading and trailing slashes are removed from every segment before
    joining, so ``url_for("/v1/", "/models")`` and
    ``url_for("v1", "models")`` yield the same URL. Slashes inside a
    segment (e.g. ``"v1/chat/completions"``) are preserved.

    Args:
        *parts: Path segments to append to the root, in order.

    Returns:
        ``self.url_root`` (without any trailing slash), a single ``/``,
        and the normalized segments joined by ``/``. With no usable
        segments the result is the root followed by ``/``.
    """
    # Trim the root as well so a root configured with a trailing slash
    # cannot introduce a double slash.
    root = self.url_root.rstrip("/")
    stripped = (part.strip("/") for part in parts if part)
    # Filter *after* stripping: a segment such as "/" collapses to "" and
    # would otherwise leave an empty path component ("//") in the URL.
    path = "/".join(segment for segment in stripped if segment)
    return f"{root}/{path}"
|
||||
|
||||
def get_client(self, **kwargs):
|
||||
if "timeout" not in kwargs:
|
||||
|
||||
Reference in New Issue
Block a user