TensorRT-LLMs/tests/unittest/llmapi/apps/test_chat_utils.py
Pengyun Lin fa61825c74
[None][feat] Support custom chat template for tool calling (#9297)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-11-25 22:07:04 +08:00

225 lines
7.7 KiB
Python

from unittest.mock import MagicMock
import pytest
from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content
@pytest.fixture
def mock_mm_data_tracker():
    """Provide a MagicMock to pass as the multimodal data tracker argument."""
    tracker = MagicMock()
    return tracker
class TestParseAssistantMessages:
    """Checks for how assistant-role messages are parsed."""

    @pytest.mark.parametrize("content", [None, "Hello, how can I help you?"])
    def test_assistant_message_no_tool_calls(
        self,
        mock_mm_data_tracker,
        content,
    ):
        """An assistant message without tool calls, with None or plain-text content."""
        parsed = parse_chat_message_content(
            {"role": "assistant", "content": content},
            mock_mm_data_tracker,
        )

        assert parsed["role"] == "assistant"
        # None content is expected to come back as an empty string.
        assert parsed["content"] == (content or "")
        assert parsed["media"] == []
        assert "tool_calls" not in parsed

    @pytest.mark.parametrize(
        "arguments",
        [
            # Arguments supplied as JSON text.
            '{"location": "San Francisco", "unit": "celsius"}',
            # Arguments supplied as a Python dict.
            {"location": "San Francisco", "unit": "celsius"},
        ],
    )
    def test_assistant_message_with_tool_calls_string_arguments(
        self, mock_mm_data_tracker, arguments
    ):
        """Tool-call arguments given as JSON text or as a dict both parse to the same dict."""

        def weather_call(call_arguments):
            # Same tool call shape for input and expectation; only the arguments vary.
            return {
                "id": "call_123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": call_arguments},
            }

        message = {
            "role": "assistant",
            "content": None,
            "tool_calls": [weather_call(arguments)],
        }

        parsed = parse_chat_message_content(message, mock_mm_data_tracker)

        assert parsed == {
            "role": "assistant",
            "content": "",
            "media": [],
            "tool_calls": [weather_call({"location": "San Francisco", "unit": "celsius"})],
        }

    def test_assistant_message_with_empty_tool_arguments(self, mock_mm_data_tracker):
        """A tool call whose arguments are None is expected back with an empty dict."""

        def time_call(call_arguments):
            # Same tool call shape for input and expectation; only the arguments vary.
            return {
                "id": "call_789",
                "type": "function",
                "function": {"name": "get_current_time", "arguments": call_arguments},
            }

        message = {
            "role": "assistant",
            "content": "Foobar",
            "tool_calls": [time_call(None)],
        }

        parsed = parse_chat_message_content(message, mock_mm_data_tracker)

        assert parsed == {
            "role": "assistant",
            "content": "Foobar",
            "media": [],
            "tool_calls": [time_call({})],
        }

    def test_assistant_message_with_multiple_tool_calls(self, mock_mm_data_tracker):
        """Several tool calls in one message: JSON-text, dict, and absent arguments."""
        incoming_calls = [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"location": "New York"}'},
            },
            {
                "id": "call_2",
                "type": "function",
                "function": {"name": "get_time", "arguments": {"timezone": "EST"}},
            },
            # This call carries no "arguments" key at all.
            {"id": "call_3", "type": "function", "function": {"name": "no_args_function"}},
        ]
        expected_calls = [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"location": "New York"}},
            },
            {
                "id": "call_2",
                "type": "function",
                "function": {"name": "get_time", "arguments": {"timezone": "EST"}},
            },
            {
                "id": "call_3",
                "type": "function",
                "function": {"name": "no_args_function", "arguments": {}},
            },
        ]
        message = {"role": "assistant", "content": None, "tool_calls": incoming_calls}

        parsed = parse_chat_message_content(message, mock_mm_data_tracker)

        assert parsed == {
            "role": "assistant",
            "content": "",
            "media": [],
            "tool_calls": expected_calls,
        }
class TestParseToolMessages:
    """Checks for how tool-role messages are parsed."""

    @pytest.mark.parametrize("content", ["The weather in San Francisco is 72°F and sunny.", None])
    def test_tool_message_with_tool_call_id(self, mock_mm_data_tracker, content):
        """A tool message carrying a tool_call_id comes back unchanged plus an empty media list."""
        # NOTE(review): the None parameter is turned into "" before the call, so
        # the parser only ever sees a string here - confirm that is intended.
        message = {"role": "tool", "content": (content or ""), "tool_call_id": "call_123"}

        parsed = parse_chat_message_content(message, mock_mm_data_tracker)

        # Built after the call, from the message as it stands at that point.
        assert parsed == dict(message, media=[])

    def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
        """A tool message lacking tool_call_id comes back unchanged plus an empty media list."""
        message = {"role": "tool", "content": "Database query completed successfully."}

        parsed = parse_chat_message_content(message, mock_mm_data_tracker)

        # Built after the call, from the message as it stands at that point.
        assert parsed == dict(message, media=[])
# ruff: noqa: E501
# ChatML-style Jinja template used by the template-loading tests below.
# Raw literals keep each `\n` as a backslash followed by "n" in the template
# text; the single real line break is the standalone "\n" between the two parts.
TEMPLATE_CHATML = (
    r"{% for message in messages %}"
    r"{{'<|im_start|>' + message['role'] + '\n' + message['content']}}"
    r"{% if (loop.last and add_generation_prompt) or not loop.last %}"
    r"{{ '<|im_end|>' + '\n'}}"
    r"{% endif %}"
    r"{% endfor %}"
    "\n"
    r"{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
    r"{{ '<|im_start|>assistant\n' }}"
    r"{% endif %}"
)
@pytest.fixture
def chat_template_path(tmp_path):
    """Write TEMPLATE_CHATML to a file under tmp_path and return that file's path."""
    template_file = tmp_path / "chat_template.jinja"
    template_file.write_text(TEMPLATE_CHATML)
    return template_file
class TestLoadChatTemplate:
    """Checks for load_chat_template across the input forms exercised here."""

    def test_load_chat_template_from_path(self, chat_template_path):
        """A path to an existing template file yields that file's contents."""
        loaded = load_chat_template(chat_template_path)
        assert loaded == TEMPLATE_CHATML

    def test_load_chat_template_from_string(self):
        """With is_literal=True the given text is returned as-is."""
        literal = "Hello, how can I help you?"
        assert load_chat_template(literal, is_literal=True) == literal

    def test_load_chat_template_from_none(self):
        """Passing None yields None."""
        assert load_chat_template(None) is None

    def test_load_chat_template_from_path_with_invalid_path(self):
        """A path-looking string that does not resolve raises ValueError."""
        missing = "invalid/path/to/chat_template.jinja"
        with pytest.raises(ValueError, match="looks like a file path"):
            load_chat_template(missing)

    def test_jinjalike_literal(self):
        """A Jinja-looking string passed without is_literal comes back unchanged."""
        jinja_text = "{{ messages }}"
        assert load_chat_template(jinja_text) == jinja_text