Collapse token splitting functionality into one class/function

Nathan Evans 2025-12-18 13:22:18 -08:00
parent b63f747d44
commit a20dbdb795
5 changed files with 142 additions and 240 deletions

View File

@@ -3,14 +3,16 @@
 """A module containing 'TokenChunker' class."""
+from collections.abc import Callable
 from typing import Any
 from graphrag.chunking.chunker import Chunker
-from graphrag.chunking.token_text_splitter import (
-    TokenTextSplitter,
-)
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
+EncodedText = list[int]
+DecodeFn = Callable[[EncodedText], str]
+EncodeFn = Callable[[str], EncodedText]
 class TokenChunker(Chunker):
     """A chunker that splits text into token-based chunks."""
@@ -26,12 +28,41 @@ class TokenChunker(Chunker):
         self._size = size
         self._overlap = overlap
         self._encoding_model = encoding_model
-        self._text_splitter = TokenTextSplitter(
-            chunk_size=size,
-            chunk_overlap=overlap,
-            tokenizer=get_tokenizer(encoding_model=encoding_model),
-        )
+        self._tokenizer = get_tokenizer(encoding_model=encoding_model)

     def chunk(self, text: str) -> list[str]:
         """Chunk the text into token-based chunks."""
-        return self._text_splitter.split_text(text)
+        return split_text_on_tokens(
+            text,
+            chunk_size=self._size,
+            chunk_overlap=self._overlap,
+            encode=self._tokenizer.encode,
+            decode=self._tokenizer.decode,
+        )
+
+
+def split_text_on_tokens(
+    text: str,
+    chunk_size: int,
+    chunk_overlap: int,
+    encode: EncodeFn,
+    decode: DecodeFn,
+) -> list[str]:
+    """Split a single text and return chunks using the tokenizer."""
+    result = []
+    input_tokens = encode(text)
+    start_idx = 0
+    cur_idx = min(start_idx + chunk_size, len(input_tokens))
+    chunk_tokens = input_tokens[start_idx:cur_idx]
+    while start_idx < len(input_tokens):
+        chunk_text = decode(list(chunk_tokens))
+        result.append(chunk_text)  # Append chunked text as string
+        if cur_idx == len(input_tokens):
+            break
+        start_idx += chunk_size - chunk_overlap
+        cur_idx = min(start_idx + chunk_size, len(input_tokens))
+        chunk_tokens = input_tokens[start_idx:cur_idx]
+    return result
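A note on the consolidated helper: the window advances by chunk_size - chunk_overlap tokens per iteration, so consecutive chunks share exactly chunk_overlap tokens and the final chunk may be shorter. A minimal sketch of that arithmetic, assuming a hypothetical one-token-per-character codec (not part of this change) so that token indices line up with character positions:

from graphrag.chunking.token_chunker import split_text_on_tokens

# Hypothetical codec: one token per character (illustration only).
def encode(s: str) -> list[int]:
    return [ord(c) for c in s]

def decode(ids: list[int]) -> str:
    return "".join(chr(i) for i in ids)

chunks = split_text_on_tokens(
    "abcdefgh",
    chunk_size=5,
    chunk_overlap=2,
    encode=encode,
    decode=decode,
)
# The window starts at token 0 and advances by 5 - 2 = 3:
# tokens[0:5] -> "abcde", then tokens[3:8] -> "defgh".
assert chunks == ["abcde", "defgh"]

In general, a text of n tokens yields one chunk when n <= chunk_size, and 1 + ceil((n - chunk_size) / (chunk_size - chunk_overlap)) chunks otherwise, so chunk_overlap must stay below chunk_size for the loop to make progress.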

View File

@@ -1,87 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'TokenTextSplitter' class and 'split_text_on_tokens' function."""
-
-import logging
-from abc import ABC
-from collections.abc import Callable
-from typing import cast
-
-import pandas as pd
-
-from graphrag.tokenizer.tokenizer import Tokenizer
-
-EncodedText = list[int]
-DecodeFn = Callable[[EncodedText], str]
-EncodeFn = Callable[[str], EncodedText]
-LengthFn = Callable[[str], int]
-
-logger = logging.getLogger(__name__)
-
-
-class TokenTextSplitter(ABC):
-    """Token text splitter class definition."""
-
-    _chunk_size: int
-    _chunk_overlap: int
-
-    def __init__(
-        self,
-        tokenizer: Tokenizer,
-        # based on OpenAI embedding chunk size limits
-        # https://devblogs.microsoft.com/azure-sql/embedding-models-and-dimensions-optimizing-the-performance-resource-usage-ratio/
-        chunk_size: int = 8191,
-        chunk_overlap: int = 100,
-    ):
-        """Init method definition."""
-        self._chunk_size = chunk_size
-        self._chunk_overlap = chunk_overlap
-        self._tokenizer = tokenizer
-
-    def num_tokens(self, text: str) -> int:
-        """Return the number of tokens in a string."""
-        return self._tokenizer.num_tokens(text)
-
-    def split_text(self, text: str) -> list[str]:
-        """Split text method."""
-        if cast("bool", pd.isna(text)) or text == "":
-            return []
-        if not isinstance(text, str):
-            msg = f"Attempting to split a non-string value, actual is {type(text)}"
-            raise TypeError(msg)
-        return split_text_on_tokens(
-            text,
-            chunk_overlap=self._chunk_overlap,
-            tokens_per_chunk=self._chunk_size,
-            decode=self._tokenizer.decode,
-            encode=self._tokenizer.encode,
-        )
-
-
-def split_text_on_tokens(
-    text: str,
-    tokens_per_chunk: int,
-    chunk_overlap: int,
-    encode: EncodeFn,
-    decode: DecodeFn,
-) -> list[str]:
-    """Split a single text and return chunks using the tokenizer."""
-    result = []
-    input_tokens = encode(text)
-    start_idx = 0
-    cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
-    chunk_tokens = input_tokens[start_idx:cur_idx]
-    while start_idx < len(input_tokens):
-        chunk_text = decode(list(chunk_tokens))
-        result.append(chunk_text)  # Append chunked text as string
-        if cur_idx == len(input_tokens):
-            break
-        start_idx += tokens_per_chunk - chunk_overlap
-        cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
-        chunk_tokens = input_tokens[start_idx:cur_idx]
-    return result

View File

@@ -10,7 +10,7 @@ from dataclasses import dataclass
 import numpy as np
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.chunking.token_text_splitter import TokenTextSplitter
+from graphrag.chunking.token_chunker import split_text_on_tokens
 from graphrag.index.utils.is_null import is_null
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.progress import ProgressTicker, progress_ticker
@@ -39,17 +39,15 @@ async def run_embed_text(
     if is_null(input):
         return TextEmbeddingResult(embeddings=None)
-    splitter = _get_splitter(tokenizer, batch_max_tokens)
     semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads)
     # Break up the input texts. The sizes here indicate how many snippets are in each input text
-    texts, input_sizes = _prepare_embed_texts(input, splitter)
+    texts, input_sizes = _prepare_embed_texts(input, tokenizer, batch_max_tokens)
     text_batches = _create_text_batches(
         texts,
+        tokenizer,
         batch_size,
         batch_max_tokens,
-        splitter,
     )
     logger.info(
         "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
@@ -72,13 +70,6 @@ async def run_embed_text(
     return TextEmbeddingResult(embeddings=embeddings)
-def _get_splitter(tokenizer: Tokenizer, batch_max_tokens: int) -> TokenTextSplitter:
-    return TokenTextSplitter(
-        tokenizer=tokenizer,
-        chunk_size=batch_max_tokens,
-    )
 async def _execute(
     model: EmbeddingModel,
     chunks: list[list[str]],
@@ -100,9 +91,9 @@ async def _execute(
 def _create_text_batches(
     texts: list[str],
+    tokenizer: Tokenizer,
     max_batch_size: int,
     max_batch_tokens: int,
-    splitter: TokenTextSplitter,
 ) -> list[list[str]]:
     """Create batches of texts to embed."""
     # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@@ -112,7 +103,7 @@ def _create_text_batches(
     current_batch_tokens = 0
     for text in texts:
-        token_count = splitter.num_tokens(text)
+        token_count = tokenizer.num_tokens(text)
         if (
             len(current_batch) >= max_batch_size
             or current_batch_tokens + token_count > max_batch_tokens
@@ -131,18 +122,23 @@ def _create_text_batches(
 def _prepare_embed_texts(
-    input: list[str], splitter: TokenTextSplitter
+    input: list[str],
+    tokenizer: Tokenizer,
+    batch_max_tokens: int = 8191,
+    chunk_overlap: int = 100,
 ) -> tuple[list[str], list[int]]:
     sizes: list[int] = []
     snippets: list[str] = []
     for text in input:
         # Split the input text and filter out any empty content
-        split_texts = splitter.split_text(text)
-        if split_texts is None:
-            continue
+        split_texts = split_text_on_tokens(
+            text,
+            chunk_size=batch_max_tokens,
+            chunk_overlap=chunk_overlap,
+            encode=tokenizer.encode,
+            decode=tokenizer.decode,
+        )
         split_texts = [text for text in split_texts if len(text) > 0]
         sizes.append(len(split_texts))
         snippets.extend(split_texts)
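For the embed-text path, the refactor keeps the helper's contract: _prepare_embed_texts returns the flattened snippets plus a sizes list recording how many non-empty snippets each input produced, which is what lets snippet embeddings be regrouped per input afterwards. A rough sketch of the expected behavior, assuming a hypothetical one-token-per-character tokenizer (only encode and decode are exercised by this function):

# Hypothetical toy tokenizer: one token per character (illustration only).
class CharTokenizer:
    def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]

    def decode(self, token_ids: list[int]) -> str:
        return "".join(chr(i) for i in token_ids)

texts, sizes = _prepare_embed_texts(
    ["abcdefgh", "xy"],
    CharTokenizer(),
    batch_max_tokens=5,
    chunk_overlap=2,
)
# "abcdefgh" splits into two overlapping snippets; "xy" fits in one.
assert texts == ["abcde", "defgh", "xy"]
assert sizes == [2, 1]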

View File

@@ -5,6 +5,9 @@ from unittest.mock import Mock, patch
 from graphrag.chunking.bootstrap import bootstrap
 from graphrag.chunking.chunker_factory import create_chunker
+from graphrag.chunking.token_chunker import (
+    split_text_on_tokens,
+)
 from graphrag.config.enums import ChunkStrategyType
 from graphrag.config.models.chunking_config import ChunkingConfig
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
@@ -62,6 +65,89 @@ class TestRunTokens:
         assert len(chunks) > 0
+
+
+class MockTokenizer:
+    def encode(self, text):
+        return [ord(char) for char in text]
+
+    def decode(self, token_ids):
+        return "".join(chr(id) for id in token_ids)
+
+
+tokenizer = get_tokenizer()
+
+
+def test_split_text_str_empty():
+    result = split_text_on_tokens(
+        "",
+        chunk_size=5,
+        chunk_overlap=2,
+        encode=tokenizer.encode,
+        decode=tokenizer.decode,
+    )
+    assert result == []
+
+
+def test_split_text_on_tokens():
+    text = "This is a test text, meaning to be taken seriously by this test only."
+    mocked_tokenizer = MockTokenizer()
+    expected_splits = [
+        "This is a ",
+        "is a test ",
+        "test text,",
+        "text, mean",
+        " meaning t",
+        "ing to be ",
+        "o be taken",
+        "taken seri",  # cspell:disable-line
+        " seriously",
+        "ously by t",  # cspell:disable-line
+        " by this t",
+        "his test o",
+        "est only.",
+    ]
+    result = split_text_on_tokens(
+        text=text,
+        chunk_overlap=5,
+        chunk_size=10,
+        decode=mocked_tokenizer.decode,
+        encode=lambda text: mocked_tokenizer.encode(text),
+    )
+    assert result == expected_splits
+
+
+def test_split_text_on_tokens_no_overlap():
+    text = "This is a test text, meaning to be taken seriously by this test only."
+    tok = get_tokenizer(encoding_model="cl100k_base")
+    expected_splits = [
+        "This is",
+        " is a",
+        " a test",
+        " test text",
+        " text,",
+        ", meaning",
+        " meaning to",
+        " to be",
+        " be taken",  # cspell:disable-line
+        " taken seriously",  # cspell:disable-line
+        " seriously by",
+        " by this",  # cspell:disable-line
+        " this test",
+        " test only",
+        " only.",
+    ]
+    result = split_text_on_tokens(
+        text=text,
+        chunk_size=2,
+        chunk_overlap=1,
+        decode=tok.decode,
+        encode=tok.encode,
+    )
+    assert result == expected_splits
 @patch("tiktoken.get_encoding")
 def test_get_encoding_fn_encode(mock_get_encoding):
     # Create a mock encoding object with encode and decode methods
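Worth noting: the ported MockTokenizer test pins down the overlap invariant directly, since one token maps to one character there. With chunk_size=10 and chunk_overlap=5, every chunk repeats the last five characters of its predecessor; a quick property check over expected_splits (hypothetical, not part of the committed tests):

# Consecutive chunks share chunk_overlap (5) characters because the
# window advances by chunk_size - chunk_overlap = 5 tokens each step.
for prev, nxt in zip(expected_splits, expected_splits[1:]):
    assert prev[5:] == nxt[:5]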

View File

@@ -1,124 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-from unittest import mock
-
-import pytest
-import tiktoken
-
-from graphrag.chunking.token_text_splitter import (
-    TokenTextSplitter,
-    split_text_on_tokens,
-)
-from graphrag.tokenizer.get_tokenizer import get_tokenizer
-
-
-class MockTokenizer:
-    def encode(self, text):
-        return [ord(char) for char in text]
-
-    def decode(self, token_ids):
-        return "".join(chr(id) for id in token_ids)
-
-
-tokenizer = get_tokenizer()
-
-
-def test_split_text_str_empty():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    result = splitter.split_text("")
-    assert result == []
-
-
-def test_split_text_str_bool():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    result = splitter.split_text(None)  # type: ignore
-    assert result == []
-
-
-def test_split_text_str_int():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    with pytest.raises(TypeError):
-        splitter.split_text(123)  # type: ignore
-
-
-@mock.patch("graphrag.chunking.token_text_splitter.split_text_on_tokens")
-def test_split_text_large_input(mock_split):
-    large_text = "a" * 10_000
-    mock_split.return_value = ["chunk"] * 2_000
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    result = splitter.split_text(large_text)
-    assert len(result) == 2_000, "Large input was not split correctly"
-    mock_split.assert_called_once()
-
-
-def test_split_text_on_tokens():
-    text = "This is a test text, meaning to be taken seriously by this test only."
-    mocked_tokenizer = MockTokenizer()
-    expected_splits = [
-        "This is a ",
-        "is a test ",
-        "test text,",
-        "text, mean",
-        " meaning t",
-        "ing to be ",
-        "o be taken",
-        "taken seri",  # cspell:disable-line
-        " seriously",
-        "ously by t",  # cspell:disable-line
-        " by this t",
-        "his test o",
-        "est only.",
-    ]
-    result = split_text_on_tokens(
-        text=text,
-        chunk_overlap=5,
-        tokens_per_chunk=10,
-        decode=mocked_tokenizer.decode,
-        encode=lambda text: mocked_tokenizer.encode(text),
-    )
-    assert result == expected_splits
-
-
-def test_split_text_on_tokens_no_overlap():
-    text = "This is a test text, meaning to be taken seriously by this test only."
-    enc = tiktoken.get_encoding("cl100k_base")
-
-    def encode(text: str) -> list[int]:
-        if not isinstance(text, str):
-            text = f"{text}"
-        return enc.encode(text)
-
-    def decode(tokens: list[int]) -> str:
-        return enc.decode(tokens)
-
-    expected_splits = [
-        "This is",
-        " is a",
-        " a test",
-        " test text",
-        " text,",
-        ", meaning",
-        " meaning to",
-        " to be",
-        " be taken",  # cspell:disable-line
-        " taken seriously",  # cspell:disable-line
-        " seriously by",
-        " by this",  # cspell:disable-line
-        " this test",
-        " test only",
-        " only.",
-    ]
-    result = split_text_on_tokens(
-        text=text,
-        chunk_overlap=1,
-        tokens_per_chunk=2,
-        decode=decode,
-        encode=lambda text: encode(text),
-    )
-    assert result == expected_splits
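The deleted test_split_text_on_tokens_no_overlap (despite its name, it ran with chunk_overlap=1) exercised the same stride arithmetic under the old keyword names: tokens_per_chunk - chunk_overlap = 1, so the window advances one token at a time and every adjacent token pair becomes a chunk. Under a hypothetical one-token-per-character codec (illustration only, not the committed test):

# Stride = tokens_per_chunk - chunk_overlap = 1: adjacent token pairs.
result = split_text_on_tokens(
    "abcd",
    tokens_per_chunk=2,
    chunk_overlap=1,
    encode=lambda s: [ord(c) for c in s],
    decode=lambda ids: "".join(chr(i) for i in ids),
)
assert result == ["ab", "bc", "cd"]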