TensorRT-LLMs/tests/unittest/llmapi/test_reasoning_parser.py
Pengyun Lin a4227cf1b0
[None][feat] Support Qwen3 reasoning parser (#8000)
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
2025-10-21 14:08:39 +08:00

74 lines
3.1 KiB
Python

import pytest
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
R1_START, R1_END = "<think>", "</think>"
@pytest.mark.parametrize(("text", "content", "reasoning_context"), [
("a b", "", "a b"),
(f"{R1_END} a b", " a b", ""),
(f"a {R1_END} b", " b", "a "),
(f"a b {R1_END}", "", "a b "),
(f"{R1_START} a {R1_END} b", " b", f"{R1_START} a "),
])
def test_deepseek_r1_reasoning_parser(text: str, content: str,
reasoning_context: str):
reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
"deepseek-r1")
result = reasoning_parser.parse(text)
assert result.content == content
assert result.reasoning_content == reasoning_context
@pytest.mark.parametrize(("delta_texts", "content", "reasoning_context"), [
(["a", "b"], ["", ""], ["a", "b"]),
([R1_END, "a", "b"], ["", "a", "b"], ["", "", ""]),
(["a", R1_END, "b"], ["", "", "b"], ["a", "", ""]),
(["a", "b", R1_END], ["", "", ""], ["a", "b", ""]),
(["a", f"l{R1_END}", "b"], ["", "", "b"], ["a", "l", ""]),
(["a", f"l{R1_END}r", "b"], ["", "r", "b"], ["a", "l", ""]),
(["a", f"{R1_END}r", "b"], ["", "r", "b"], ["a", "", ""]),
])
def test_deepseek_r1_reasoning_parser_stream(delta_texts: list, content: list,
reasoning_context: list):
reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
"deepseek-r1")
for i, delta_text in enumerate(delta_texts):
result = reasoning_parser.parse_delta(delta_text)
assert result.content == content[i]
assert result.reasoning_content == reasoning_context[i]
@pytest.mark.parametrize(("text", "content", "reasoning_context"), [
("a<think>b</think>c", "c", "b"),
("<think>a</think>b", "b", "a"),
("<think>a", "", "a"),
("a", "a", ""),
("<think>", "", ""),
])
def test_qwen3_reasoning_parser(text: str, content: str,
reasoning_context: str):
reasoning_parser = ReasoningParserFactory.create_reasoning_parser("qwen3")
result = reasoning_parser.parse(text)
assert result.content == content
assert result.reasoning_content == reasoning_context
@pytest.mark.parametrize(("delta_texts", "content", "reasoning_context"), [
(["<think>a", "l</think>r", "b"], ["", "r", "b"], ["a", "l", ""]),
(["<th", "ink>a</think>b"], ["", "b"], ["", "a"]),
(["<think>a</th", "ink>b"], ["", "b"], ["a", ""]),
(["<think>", "a</think>b"], ["", "b"], ["", "a"]),
(["<think>a</think>", "b"], ["", "b"], ["a", ""]),
(["<think>a</th", "ank></th", "ink>b"], ["", "", "b"
], ["a", "</thank>", ""]),
])
def test_qwen3_reasoning_parser_stream(delta_texts: list, content: list,
reasoning_context: list):
reasoning_parser = ReasoningParserFactory.create_reasoning_parser("qwen3")
for i, delta_text in enumerate(delta_texts):
result = reasoning_parser.parse_delta(delta_text)
assert result.content == content[i]
assert result.reasoning_content == reasoning_context[i]