mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
225 lines
7.7 KiB
Python
225 lines
7.7 KiB
Python
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_mm_data_tracker():
|
|
"""Create a mock MultimodalDataTracker for testing."""
|
|
return MagicMock()
|
|
|
|
|
|
class TestParseAssistantMessages:
|
|
"""Test suite for assistant role messages."""
|
|
|
|
@pytest.mark.parametrize("content", [None, "Hello, how can I help you?"])
|
|
def test_assistant_message_no_tool_calls(
|
|
self,
|
|
mock_mm_data_tracker,
|
|
content,
|
|
):
|
|
"""Test parsing an assistant message with simple string content."""
|
|
message = {"role": "assistant", "content": content}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
assert result["role"] == "assistant"
|
|
assert result["content"] == (content or "")
|
|
assert result["media"] == []
|
|
assert "tool_calls" not in result
|
|
|
|
@pytest.mark.parametrize(
|
|
"arguments",
|
|
[
|
|
# JSON string.
|
|
'{"location": "San Francisco", "unit": "celsius"}',
|
|
# Python dict.
|
|
{"location": "San Francisco", "unit": "celsius"},
|
|
],
|
|
)
|
|
def test_assistant_message_with_tool_calls_string_arguments(
|
|
self, mock_mm_data_tracker, arguments
|
|
):
|
|
"""Test parsing an assistant message with tool calls where arguments are JSON strings."""
|
|
message = {
|
|
"role": "assistant",
|
|
"content": None,
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_123",
|
|
"type": "function",
|
|
"function": {
|
|
"name": "get_weather",
|
|
"arguments": arguments,
|
|
},
|
|
}
|
|
],
|
|
}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
assert result == {
|
|
"role": "assistant",
|
|
"content": "",
|
|
"media": [],
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_123",
|
|
"type": "function",
|
|
"function": {
|
|
"name": "get_weather",
|
|
"arguments": {"location": "San Francisco", "unit": "celsius"},
|
|
},
|
|
}
|
|
],
|
|
}
|
|
|
|
def test_assistant_message_with_empty_tool_arguments(self, mock_mm_data_tracker):
|
|
"""Test parsing an assistant message with tool calls that have no arguments."""
|
|
message = {
|
|
"role": "assistant",
|
|
"content": "Foobar",
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_789",
|
|
"type": "function",
|
|
"function": {"name": "get_current_time", "arguments": None},
|
|
}
|
|
],
|
|
}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
expected = {
|
|
"role": "assistant",
|
|
"content": "Foobar",
|
|
"media": [],
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_789",
|
|
"type": "function",
|
|
"function": {"name": "get_current_time", "arguments": {}},
|
|
}
|
|
],
|
|
}
|
|
assert result == expected
|
|
|
|
def test_assistant_message_with_multiple_tool_calls(self, mock_mm_data_tracker):
|
|
"""Test parsing an assistant message with multiple tool calls."""
|
|
message = {
|
|
"role": "assistant",
|
|
"content": None,
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_1",
|
|
"type": "function",
|
|
"function": {"name": "get_weather", "arguments": '{"location": "New York"}'},
|
|
},
|
|
{
|
|
"id": "call_2",
|
|
"type": "function",
|
|
"function": {"name": "get_time", "arguments": {"timezone": "EST"}},
|
|
},
|
|
{"id": "call_3", "type": "function", "function": {"name": "no_args_function"}},
|
|
],
|
|
}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
expected = {
|
|
"role": "assistant",
|
|
"content": "",
|
|
"media": [],
|
|
"tool_calls": [
|
|
{
|
|
"id": "call_1",
|
|
"type": "function",
|
|
"function": {"name": "get_weather", "arguments": {"location": "New York"}},
|
|
},
|
|
{
|
|
"id": "call_2",
|
|
"type": "function",
|
|
"function": {"name": "get_time", "arguments": {"timezone": "EST"}},
|
|
},
|
|
{
|
|
"id": "call_3",
|
|
"type": "function",
|
|
"function": {"name": "no_args_function", "arguments": {}},
|
|
},
|
|
],
|
|
}
|
|
assert result == expected
|
|
|
|
|
|
class TestParseToolMessages:
|
|
"""Test suite for tool role messages."""
|
|
|
|
@pytest.mark.parametrize("content", ["The weather in San Francisco is 72°F and sunny.", None])
|
|
def test_tool_message_with_tool_call_id(self, mock_mm_data_tracker, content):
|
|
"""Test parsing a tool message with tool_call_id."""
|
|
message = {"role": "tool", "content": (content or ""), "tool_call_id": "call_123"}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
expected = {**message, "media": []}
|
|
assert result == expected
|
|
|
|
def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker):
|
|
"""Test parsing a tool message without tool_call_id."""
|
|
message = {
|
|
"role": "tool",
|
|
"content": "Database query completed successfully.",
|
|
}
|
|
|
|
result = parse_chat_message_content(message, mock_mm_data_tracker)
|
|
|
|
expected = {**message, "media": []}
|
|
assert result == expected
|
|
|
|
|
|
# ruff: noqa: E501
|
|
TEMPLATE_CHATML = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
|
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
|
|
|
|
|
|
@pytest.fixture
|
|
def chat_template_path(tmp_path):
|
|
"""Return the path to the chat template."""
|
|
temp_file_path = tmp_path / "chat_template.jinja"
|
|
with open(temp_file_path, "w") as f:
|
|
f.write(TEMPLATE_CHATML)
|
|
return temp_file_path
|
|
|
|
|
|
class TestLoadChatTemplate:
|
|
"""Test suite for loading chat templates."""
|
|
|
|
def test_load_chat_template_from_path(self, chat_template_path):
|
|
"""Test loading a chat template from a path."""
|
|
template = load_chat_template(chat_template_path)
|
|
assert template == TEMPLATE_CHATML
|
|
|
|
def test_load_chat_template_from_string(self):
|
|
"""Test loading a chat template from a string."""
|
|
text = "Hello, how can I help you?"
|
|
template = load_chat_template(text, is_literal=True)
|
|
assert template == text
|
|
|
|
def test_load_chat_template_from_none(self):
|
|
"""Test loading a chat template from None."""
|
|
template = load_chat_template(None)
|
|
assert template is None
|
|
|
|
def test_load_chat_template_from_path_with_invalid_path(self):
|
|
"""Test loading a chat template from a path with an invalid path."""
|
|
with pytest.raises(ValueError, match="looks like a file path"):
|
|
load_chat_template("invalid/path/to/chat_template.jinja")
|
|
|
|
def test_jinjalike_literal(self):
|
|
"""Test loading a chat template from a jinja-like string."""
|
|
template = "{{ messages }}"
|
|
template_content = load_chat_template(template)
|
|
assert template_content == template
|