[ROCm][CI] Stabilize Granite tool-use and test URL construction (#43017)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-05-22 23:21:11 -05:00
committed by GitHub
parent 6a4723a2e0
commit 76ea1d5d2f
5 changed files with 46 additions and 11 deletions
@@ -1,4 +1,8 @@
{%- if tools %}
{%- if messages and messages[0]['role'] != 'system' %}
{{- '<|start_of_role|>system<|end_of_role|>You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user query, respond with <|tool_call|> followed by a JSON list of tools used.<|end_of_text|>
' }}
{%- endif %}
{{- '<|start_of_role|>available_tools<|end_of_role|>
' }}
{%- for tool in tools %}
+28 -7
View File
@@ -13,9 +13,19 @@ from .utils import (
SEED,
WEATHER_TOOL,
ServerConfig,
ensure_system_prompt,
)
def apply_parallel_tool_system_prompt(
messages,
server_config: ServerConfig,
):
if server_config["model"] == "ibm-granite/granite-3.0-8b-instruct":
return ensure_system_prompt(messages, server_config)
return messages
# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
@@ -33,8 +43,11 @@ async def test_parallel_tool_calls(
models = await client.models.list()
model_name: str = models.data[0].id
messages = apply_parallel_tool_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
)
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=200,
model=model_name,
@@ -73,7 +86,7 @@ async def test_parallel_tool_calls(
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
@@ -162,8 +175,11 @@ async def test_parallel_tool_calls_with_results(
models = await client.models.list()
model_name: str = models.data[0].id
messages = apply_parallel_tool_system_prompt(
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, server_config
)
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
messages=messages,
temperature=0,
max_completion_tokens=200,
model=model_name,
@@ -182,7 +198,7 @@ async def test_parallel_tool_calls_with_results(
assert "78" in choice.message.content # Orlando temp in tool response
stream = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
messages=messages,
temperature=0,
max_completion_tokens=200,
model=model_name,
@@ -220,15 +236,20 @@ async def test_parallel_tool_calls_with_results(
@pytest.mark.asyncio
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
async def test_parallel_tool_calls_false(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
"""
Ensure only one tool call is returned when parallel_tool_calls is False.
"""
models = await client.models.list()
model_name: str = models.data[0].id
messages = apply_parallel_tool_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
)
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=200,
model=model_name,
@@ -248,7 +269,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
+8 -3
View File
@@ -12,17 +12,22 @@ from .utils import (
SEARCH_TOOL,
SEED,
WEATHER_TOOL,
ServerConfig,
ensure_system_prompt,
)
# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
async def test_tool_call_and_choice(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
models = await client.models.list()
model_name: str = models.data[0].id
messages = ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config)
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=100,
model=model_name,
@@ -68,7 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# make the same request, streaming
stream = await client.chat.completions.create(
model=model_name,
messages=MESSAGES_ASKING_FOR_TOOLS,
messages=messages,
temperature=0,
max_completion_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
+4
View File
@@ -201,6 +201,10 @@ CONFIGS: dict[str, ServerConfig] = {
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"),
],
"system_prompt": "You are a helpful AI assistant with access to tools. "
"Use two-letter US state abbreviations in weather tool arguments. "
"When a tool is required to answer the user query, respond with "
"<|tool_call|> followed by a JSON list of tools used.",
},
"granite-3.1-8b": {
"model": "ibm-granite/granite-3.1-8b-instruct",
+2 -1
View File
@@ -660,7 +660,8 @@ class RemoteVLLMServer:
)
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
path = "/".join(part.strip("/") for part in parts if part)
return f"{self.url_root}/{path}"
def get_client(self, **kwargs):
if "timeout" not in kwargs: