mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
140 lines
5.9 KiB
Python
140 lines
5.9 KiB
Python
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional, Type
|
|
|
|
|
|
@dataclass
|
|
class ReasoningParserResult:
|
|
content: str = ""
|
|
reasoning_content: str = ""
|
|
|
|
|
|
class BaseReasoningParser(ABC):
|
|
|
|
@abstractmethod
|
|
def parse(self, text: str) -> ReasoningParserResult:
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
|
|
raise NotImplementedError
|
|
|
|
|
|
class DeepSeekR1Parser(BaseReasoningParser):
|
|
"""
|
|
Reasoning parser for DeepSeek-R1. Reasoning format: <think>(.*)</think>.
|
|
Since the latest official tokenizer_config.json initially adds "<think>\\n" at the end of the prompt
|
|
(https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/tokenizer_config.json),
|
|
treat all the text before the </think> tag as `reasoning_content` and the text after as `content`.
|
|
"""
|
|
|
|
def __init__(self, reasoning_at_start: bool = False) -> None:
|
|
self.reasoning_start = "<think>"
|
|
self.reasoning_end = "</think>"
|
|
self.reasoning_at_start = reasoning_at_start
|
|
self.in_reasoning = self.reasoning_at_start
|
|
self._buffer = ""
|
|
|
|
def _create_reasoning_end_result(self, content: str,
|
|
reasoning_content: str):
|
|
if len(content) == 0:
|
|
reasoning_parser_result = ReasoningParserResult(
|
|
reasoning_content=reasoning_content)
|
|
elif len(reasoning_content) == 0:
|
|
reasoning_parser_result = ReasoningParserResult(content=content)
|
|
else:
|
|
reasoning_parser_result = ReasoningParserResult(
|
|
content=content, reasoning_content=reasoning_content)
|
|
return reasoning_parser_result
|
|
|
|
def parse(self, text: str) -> ReasoningParserResult:
|
|
if not self.reasoning_at_start:
|
|
splits = text.partition(self.reasoning_start)
|
|
if splits[1] == "":
|
|
# no reasoning start tag found
|
|
return ReasoningParserResult(content=text)
|
|
# reasoning start tag found
|
|
# text before reasoning start tag is dropped
|
|
text = splits[2]
|
|
splits = text.partition(self.reasoning_end)
|
|
reasoning_content, content = splits[0], splits[2]
|
|
return ReasoningParserResult(content=content,
|
|
reasoning_content=reasoning_content)
|
|
|
|
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
|
|
self._buffer += delta_text
|
|
delta_text = self._buffer
|
|
reasoning_content = None
|
|
content = None
|
|
if (self.reasoning_start.startswith(delta_text)
|
|
or self.reasoning_end.startswith(delta_text)):
|
|
# waiting for more text to determine if it's a reasoning start or end tag
|
|
return ReasoningParserResult()
|
|
|
|
if not self.in_reasoning:
|
|
begin_idx = delta_text.find(self.reasoning_start)
|
|
if begin_idx == -1:
|
|
self._buffer = ""
|
|
return ReasoningParserResult(content=delta_text)
|
|
self.in_reasoning = True
|
|
# set reasoning_content, will be processed by the next block
|
|
reasoning_content = delta_text[begin_idx +
|
|
len(self.reasoning_start):]
|
|
|
|
if self.in_reasoning:
|
|
delta_text = reasoning_content if reasoning_content is not None else delta_text
|
|
end_idx = delta_text.find(self.reasoning_end)
|
|
if end_idx == -1:
|
|
last_idx = delta_text.rfind(self.reasoning_end[0])
|
|
if last_idx != -1 and self.reasoning_end.startswith(
|
|
delta_text[last_idx:]):
|
|
self._buffer = delta_text[last_idx:]
|
|
reasoning_content = delta_text[:last_idx]
|
|
else:
|
|
self._buffer = ""
|
|
reasoning_content = delta_text
|
|
return ReasoningParserResult(
|
|
reasoning_content=reasoning_content)
|
|
reasoning_content = delta_text[:end_idx]
|
|
content = delta_text[end_idx + len(self.reasoning_end):]
|
|
self.in_reasoning = False
|
|
self._buffer = ""
|
|
return ReasoningParserResult(content=content,
|
|
reasoning_content=reasoning_content)
|
|
raise RuntimeError(
|
|
"Unreachable code reached in `DeepSeekR1Parser.parse_delta`")
|
|
|
|
|
|
class ReasoningParserFactory:
|
|
parsers: dict[str, Type[BaseReasoningParser]] = {
|
|
"deepseek-r1": DeepSeekR1Parser,
|
|
"qwen3": DeepSeekR1Parser,
|
|
"nano-v3": DeepSeekR1Parser,
|
|
}
|
|
|
|
@staticmethod
|
|
def create_reasoning_parser(
|
|
reasoning_parser: str,
|
|
chat_template_kwargs: Optional[dict[str, Any]] = None
|
|
) -> BaseReasoningParser:
|
|
try:
|
|
reasoning_parser_class = ReasoningParserFactory.parsers[
|
|
reasoning_parser.lower()]
|
|
if reasoning_parser == "deepseek-r1":
|
|
return reasoning_parser_class(reasoning_at_start=True)
|
|
elif reasoning_parser == "nano-v3":
|
|
# Note: If the model is with reasoning (default behavior), `reasoning_at_start` should be True, and the starting response should be parsed into `reasoning_content`.
|
|
# While the model is without reasoning, `reasoning_at_start` should be False to parse the response into `content` fields.
|
|
is_reasoning_model = True
|
|
if isinstance(chat_template_kwargs, dict):
|
|
is_reasoning_model = chat_template_kwargs.get(
|
|
"enable_thinking", True)
|
|
return reasoning_parser_class(
|
|
reasoning_at_start=is_reasoning_model)
|
|
return reasoning_parser_class()
|
|
except KeyError as e:
|
|
raise ValueError(
|
|
f"Invalid reasoning parser: {reasoning_parser}\n"
|
|
f"Supported parsers: {list(ReasoningParserFactory.parsers.keys())}"
|
|
) from e
|