from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional, Type


@dataclass
class ReasoningParserResult:
    """Outcome of splitting model output into reasoning and final content.

    Attributes are declared as real dataclass fields so that the generated
    ``__eq__`` / ``__repr__`` actually reflect the stored values.
    """

    # Whether the parser still considers the text to be inside the reasoning section.
    in_reasoning: bool
    # Text after the reasoning section, or None if there is none in this result.
    content: Optional[str] = None
    # Text belonging to the reasoning section, or None if there is none in this result.
    reasoning_content: Optional[str] = None


class BaseReasoningParser(ABC):
    """Interface for parsers that separate reasoning text from final content."""

    @abstractmethod
    def parse(self, text: str) -> ReasoningParserResult:
        """Parse a complete generated text in one shot."""
        raise NotImplementedError

    @abstractmethod
    def parse_delta(self, delta_text: str) -> ReasoningParserResult:
        """Parse the next streamed chunk of generated text."""
        raise NotImplementedError


class DeepSeekR1Parser(BaseReasoningParser):
    """
    Reasoning parser for DeepSeek-R1. Reasoning format: <think>(.*)</think>.
    Since the latest official tokenizer_config.json initially adds "<think>\\n" at the end of the prompt
    (https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/tokenizer_config.json),
    treat all the text before the </think> tag as `reasoning_content` and the text after as `content`.
    """

    def __init__(self, reasoning_end: str = "</think>"):
        """Create a parser.

        Args:
            reasoning_end: Marker that terminates the reasoning section.

        Raises:
            ValueError: If ``reasoning_end`` is empty (an empty marker would
                match everywhere and cannot be used as a separator).
        """
        if not reasoning_end:
            raise ValueError("reasoning_end must be a non-empty string")
        self.reasoning_end = reasoning_end
        # Streaming state: True until the end marker is seen by parse_delta.
        self.in_reasoning = True

    def _create_reasoning_end_result(self, content: str,
                                     reasoning_content: str):
        """Build the result for text in which the end marker was found.

        Only the non-empty parts are populated.
        """
        # NOTE(review): when nothing follows the end marker, the result still
        # reports in_reasoning=True (existing behavior, kept for callers).
        if not content:
            return ReasoningParserResult(
                True, reasoning_content=reasoning_content)
        if not reasoning_content:
            return ReasoningParserResult(False, content=content)
        return ReasoningParserResult(False,
                                     content=content,
                                     reasoning_content=reasoning_content)

    def parse(self, text: str) -> ReasoningParserResult:
        """Split a complete text at the first end marker.

        Does not read or modify the streaming state used by ``parse_delta``.
        Without a marker, the whole text is reported as reasoning.
        """
        reasoning_content, sep, content = text.partition(self.reasoning_end)
        if not sep:
            return ReasoningParserResult(True, reasoning_content=text)
        return self._create_reasoning_end_result(content, reasoning_content)

    def parse_delta(self, delta_text: str) -> ReasoningParserResult:
        """Classify one streamed chunk, updating the streaming state.

        Before the end marker, chunks are reasoning. The chunk containing the
        first marker is split around it, and afterwards every chunk (including
        any later marker text) is content.
        """
        # NOTE(review): a marker split across two chunks is not detected.
        if not self.in_reasoning:
            return ReasoningParserResult(False, content=delta_text)
        reasoning_content, sep, content = delta_text.partition(
            self.reasoning_end)
        if not sep:
            return ReasoningParserResult(True, reasoning_content=delta_text)
        self.in_reasoning = False
        return self._create_reasoning_end_result(content, reasoning_content)


class ReasoningParserFactory:
    """Creates reasoning parsers by (case-insensitive) name."""

    # Maps a lower-case parser name to the parser class (not an instance).
    parsers: Dict[str, Type[BaseReasoningParser]] = {
        "deepseek-r1": DeepSeekR1Parser,
    }

    @staticmethod
    def create_reasoning_parser(reasoning_parser: str) -> BaseReasoningParser:
        """Instantiate the parser registered under ``reasoning_parser``.

        Raises:
            ValueError: If no parser is registered under that name.
        """
        # Normalize once so the existence check and the lookup agree.
        parser_class = ReasoningParserFactory.parsers.get(
            reasoning_parser.lower())
        if parser_class is None:
            raise ValueError(f"Invalid reasoning_parser: {reasoning_parser}")
        return parser_class()