TensorRT-LLMs/tests/unittest/llmapi/apps/test_chat_utils.py
William Zhang 121140cfec
[None][fixes] Add tool call parsing fixes and Qwen3 coder parser (#8817)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-11-13 04:34:38 -08:00

180 lines
5.8 KiB
Python

from unittest.mock import MagicMock
import pytest
from tensorrt_llm.serve.chat_utils import parse_chat_message_content
@pytest.fixture
def mock_mm_data_tracker():
    """Provide a stand-in for the MultimodalDataTracker argument.

    Every message in these tests is text-only, so a ``MagicMock`` (which
    accepts arbitrary attribute access and calls) is sufficient as the
    tracker passed to ``parse_chat_message_content``.
    """
    tracker = MagicMock()
    return tracker
class TestParseAssistantMessages:
    """Test suite for assistant role messages."""

    @pytest.mark.parametrize("content", [None, "Hello, how can I help you?"])
    def test_assistant_message_no_tool_calls(
        self,
        mock_mm_data_tracker,
        content,
    ):
        """Test parsing an assistant message that has no tool calls.

        ``content`` is parametrized over both ``None`` and a plain string;
        ``None`` is expected to come back as an empty string, and the result
        must not contain a ``tool_calls`` key.
        """
        message = {"role": "assistant", "content": content}

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        assert result["role"] == "assistant"
        assert result["content"] == (content or "")
        assert result["media"] == []
        assert "tool_calls" not in result

    @pytest.mark.parametrize(
        "arguments",
        [
            # JSON string.
            '{"location": "San Francisco", "unit": "celsius"}',
            # Python dict.
            {"location": "San Francisco", "unit": "celsius"},
        ],
    )
    def test_assistant_message_with_tool_calls_string_arguments(
        self, mock_mm_data_tracker, arguments
    ):
        """Test parsing an assistant message with a single tool call.

        Despite the method name (kept for stable test IDs), ``arguments`` is
        parametrized over both a JSON-encoded string and an already-decoded
        dict; in either case the parsed tool call is expected to carry the
        arguments as a dict. ``None`` content is expected to become ``""``.
        """
        message = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": arguments,
                    },
                }
            ],
        }

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        assert result == {
            "role": "assistant",
            "content": "",
            "media": [],
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": {"location": "San Francisco", "unit": "celsius"},
                    },
                }
            ],
        }

    def test_assistant_message_with_empty_tool_arguments(self, mock_mm_data_tracker):
        """Test parsing a tool call whose ``arguments`` is explicitly ``None``.

        ``None`` arguments are expected to be normalized to an empty dict,
        while the string content is passed through unchanged.
        """
        message = {
            "role": "assistant",
            "content": "Foobar",
            "tool_calls": [
                {
                    "id": "call_789",
                    "type": "function",
                    "function": {"name": "get_current_time", "arguments": None},
                }
            ],
        }

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        expected = {
            "role": "assistant",
            "content": "Foobar",
            "media": [],
            "tool_calls": [
                {
                    "id": "call_789",
                    "type": "function",
                    "function": {"name": "get_current_time", "arguments": {}},
                }
            ],
        }
        assert result == expected

    def test_assistant_message_with_multiple_tool_calls(self, mock_mm_data_tracker):
        """Test parsing an assistant message with several tool calls.

        The calls mix the three argument forms: a JSON string, a dict, and a
        missing ``arguments`` key. All are expected to end up as dicts, with
        the missing key becoming ``{}``. The original call order is preserved.
        """
        message = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "New York"}'},
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {"name": "get_time", "arguments": {"timezone": "EST"}},
                },
                # No "arguments" key at all.
                {"id": "call_3", "type": "function", "function": {"name": "no_args_function"}},
            ],
        }

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        expected = {
            "role": "assistant",
            "content": "",
            "media": [],
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": {"location": "New York"}},
                },
                {
                    "id": "call_2",
                    "type": "function",
                    "function": {"name": "get_time", "arguments": {"timezone": "EST"}},
                },
                {
                    "id": "call_3",
                    "type": "function",
                    "function": {"name": "no_args_function", "arguments": {}},
                },
            ],
        }
        assert result == expected
class TestParseToolMessages:
    """Test suite for tool role messages."""

    @pytest.mark.parametrize(
        "content", ["The weather in San Francisco is 72°F and sunny.", None]
    )
    def test_tool_message_with_tool_call_id(self, mock_mm_data_tracker, content):
        """Test parsing a tool message that carries a ``tool_call_id``.

        ``content`` is handed to the parser unchanged, including ``None``.
        Previously the message was built with ``content or ""``, so the
        ``None`` case never reached the parser. ``None`` is expected to come
        back as ``""``, the same normalization asserted for assistant messages.
        """
        message = {"role": "tool", "content": content, "tool_call_id": "call_123"}

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        # NOTE(review): assumes tool messages normalize None content to "" the
        # same way assistant messages do; confirm against chat_utils.
        expected = {**message, "content": content or "", "media": []}
        assert result == expected

    def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
        """Test parsing a tool message that has no ``tool_call_id``.

        The result should mirror the input plus an empty ``media`` list, with
        no ``tool_call_id`` key added.
        """
        message = {
            "role": "tool",
            "content": "Database query completed successfully.",
        }

        result = parse_chat_message_content(message, mock_mm_data_tracker)

        expected = {**message, "media": []}
        assert result == expected