[Bugfix] Remove tokenizer encode/decode calls from Olmo3 reasoning parser (#40855)

Signed-off-by: Yifan <yzong@redhat.com>
Co-authored-by: Flora Feng <4florafeng@gmail.com>
This commit is contained in:
yzong-rh
2026-04-27 22:36:54 -04:00
committed by GitHub
parent 03aeed802f
commit 0d4f714208
2 changed files with 42 additions and 16 deletions
+12 -1
View File
@@ -41,6 +41,12 @@ SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES = {
"content": "\n\n\nThis is the rest",
}
# Reasoning text that ends with a space right before the closing tag.
SIMPLE_REASONING_WITH_TRAILING_SPACE = {
    "output": (
        f"{START_REASONING}\nLook!\nI'm thinking... {END_REASONING}\nThis is the rest"
    ),
    "reasoning": "\nLook!\nI'm thinking... ",
    "content": "\nThis is the rest",
}
NO_REASONING_ONLY_END_THINK = {
"output": f"{END_REASONING}\n\nNo thoughts, head empty!",
"reasoning": None,
@@ -114,6 +120,11 @@ TEST_CASES = [
SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES,
id="simple_reasoning_with_multiple_newlines_streaming",
),
pytest.param(
True, # enable streaming
SIMPLE_REASONING_WITH_TRAILING_SPACE,
id="simple_reasoning_with_trailing_space_streaming",
),
pytest.param(
True, # enable streaming
NO_REASONING_ONLY_END_THINK,
@@ -127,7 +138,7 @@ TEST_CASES = [
]
# Load the tokenizer once at import time so every parametrized test case
# reuses it instead of loading it again.
tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Think")
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
+30 -15
View File
@@ -218,24 +218,32 @@ class Olmo3ReasoningParser(ReasoningParser):
token is missing from generation.
"""
# Literal tags that open and close the reasoning section.
think_start: str = r"<think>"
think_end: str = r"</think>"
# The pre-tokenizer splits `</think>` into three pieces. The first piece can
# be tokenized with or without a leading space (`Ġ</` or `</`), so the tag has
# two possible tokenizations; the remaining two pieces follow in order.
think_end_first_split: list[str] = [r"Ġ</", r"</"]
think_end_rest_split: list[str] = [r"think", r">"]
# Note that the opening <think> is optional. This lets the pattern match when
# a <think> is hardcoded at the beginning of the reasoning template.
# re.DOTALL lets `.` match newlines inside the reasoning and the content.
reasoning_regex: re.Pattern = re.compile(
rf"^(?:{think_start})?(?P<reasoning>.*?)"
rf"{think_end}(?P<content>.*)$",
re.DOTALL,
)
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
    """Create the streaming buffer and resolve the ``</think>`` token ids.

    The tag strings and the reasoning regex are class-level constants, so
    they are not re-created here; only per-instance state is set up.

    Args:
        tokenizer: Tokenizer forwarded to the base parser.
        *args: Extra positional arguments forwarded to the base parser.
        **kwargs: Extra keyword arguments forwarded to the base parser.
    """
    super().__init__(tokenizer, *args, **kwargs)

    self.buffer = Olmo3ReasoningBuffer(
        think_start=self.think_start, think_end=self.think_end
    )

    # Look up the ids of the end-tag pieces once so that is_reasoning_end can
    # compare token ids directly instead of decoding text.
    # NOTE(review): assumes self.vocab maps token strings to ids and contains
    # every piece; a missing piece would fail here at construction time.
    self.think_end_first_token_ids: list[int] = [
        self.vocab[token] for token in self.think_end_first_split
    ]
    self.think_end_rest_token_ids: list[int] = [
        self.vocab[token] for token in self.think_end_rest_split
    ]
@property
def reasoning_start_str(self) -> str:
@@ -246,8 +254,15 @@ class Olmo3ReasoningParser(ReasoningParser):
return self.think_end
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """Report whether ``input_ids`` contains a tokenized ``</think>`` tag.

    The check compares token ids only; the tokenizer is never asked to
    decode the sequence. A match is one id from
    ``self.think_end_first_token_ids`` immediately followed by the ids in
    ``self.think_end_rest_token_ids``.

    Args:
        input_ids: Token ids to search.

    Returns:
        True if the end-tag token pattern occurs anywhere in ``input_ids``,
        False otherwise (including when the sequence is too short to hold it).
    """
    first_ids = self.think_end_first_token_ids
    rest_ids = self.think_end_rest_token_ids
    rest_len = len(rest_ids)

    # Walk backwards, starting at the last index that still leaves room for
    # all of the remaining tag pieces after it.
    for i in range(len(input_ids) - rest_len - 1, -1, -1):
        if (
            input_ids[i] in first_ids
            and list(input_ids[i + 1 : i + 1 + rest_len]) == rest_ids
        ):
            return True
    return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
# for Olmo 3 streaming reason parsing, the stream parse