Co-locate chunking/splitting

Nathan Evans 2025-12-18 11:58:53 -08:00
parent 461291706f
commit b63f747d44
8 changed files with 39 additions and 69 deletions

View File

@@ -6,8 +6,8 @@
 from typing import Any
 
 from graphrag.chunking.chunker import Chunker
-from graphrag.index.text_splitting.text_splitting import (
-    split_single_text_on_tokens,
+from graphrag.chunking.token_text_splitter import (
+    TokenTextSplitter,
 )
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
@@ -26,14 +26,12 @@ class TokenChunker(Chunker):
         self._size = size
         self._overlap = overlap
         self._encoding_model = encoding_model
+        self._text_splitter = TokenTextSplitter(
+            chunk_size=size,
+            chunk_overlap=overlap,
+            tokenizer=get_tokenizer(encoding_model=encoding_model),
+        )
 
     def chunk(self, text: str) -> list[str]:
         """Chunk the text into token-based chunks."""
-        tokenizer = get_tokenizer(encoding_model=self._encoding_model)
-        return split_single_text_on_tokens(
-            text,
-            chunk_overlap=self._overlap,
-            tokens_per_chunk=self._size,
-            encode=tokenizer.encode,
-            decode=tokenizer.decode,
-        )
+        return self._text_splitter.split_text(text)
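For illustration, a minimal usage sketch of the refactored TokenChunker: the splitter is now built once in the constructor and chunk() simply delegates to it. The parameter names (size, overlap, encoding_model) are inferred from the assignments above; the import path and the encoding name are assumptions not shown in this diff.

# Hypothetical usage sketch; module path and constructor arguments are assumed.
from graphrag.chunking.token_chunker import TokenChunker

chunker = TokenChunker(size=1200, overlap=100, encoding_model="cl100k_base")
chunks = chunker.chunk("a long input document ...")  # list[str] of overlapping token-window chunks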

View File

@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""A module containing 'TokenTextSplitter' class and 'split_single_text_on_tokens' function."""
+"""A module containing 'TokenTextSplitter' class and 'split_text_on_tokens' function."""
 
 import logging
 from abc import ABC
@@ -10,7 +10,6 @@ from typing import cast
 
 import pandas as pd
 
-from graphrag.tokenizer.get_tokenizer import get_tokenizer
 from graphrag.tokenizer.tokenizer import Tokenizer
 
 EncodedText = list[int]
@@ -26,47 +25,33 @@ class TokenTextSplitter(ABC):
     _chunk_size: int
     _chunk_overlap: int
-    _length_function: LengthFn
-    _keep_separator: bool
-    _add_start_index: bool
-    _strip_whitespace: bool
 
     def __init__(
         self,
+        tokenizer: Tokenizer,
         # based on OpenAI embedding chunk size limits
        # https://devblogs.microsoft.com/azure-sql/embedding-models-and-dimensions-optimizing-the-performance-resource-usage-ratio/
         chunk_size: int = 8191,
         chunk_overlap: int = 100,
-        length_function: LengthFn = len,
-        keep_separator: bool = False,
-        add_start_index: bool = False,
-        strip_whitespace: bool = True,
-        tokenizer: Tokenizer | None = None,
     ):
         """Init method definition."""
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
-        self._length_function = length_function
-        self._keep_separator = keep_separator
-        self._add_start_index = add_start_index
-        self._strip_whitespace = strip_whitespace
-        self._tokenizer = tokenizer or get_tokenizer()
+        self._tokenizer = tokenizer
 
     def num_tokens(self, text: str) -> int:
         """Return the number of tokens in a string."""
         return self._tokenizer.num_tokens(text)
 
-    def split_text(self, text: str | list[str]) -> list[str]:
+    def split_text(self, text: str) -> list[str]:
         """Split text method."""
-        if isinstance(text, list):
-            text = " ".join(text)
-        elif cast("bool", pd.isna(text)) or text == "":
+        if cast("bool", pd.isna(text)) or text == "":
             return []
         if not isinstance(text, str):
             msg = f"Attempting to split a non-string value, actual is {type(text)}"
             raise TypeError(msg)
 
-        return split_single_text_on_tokens(
+        return split_text_on_tokens(
             text,
             chunk_overlap=self._chunk_overlap,
             tokens_per_chunk=self._chunk_size,
@@ -75,7 +60,7 @@ class TokenTextSplitter(ABC):
         )
 
 
-def split_single_text_on_tokens(
+def split_text_on_tokens(
     text: str,
     tokens_per_chunk: int,
     chunk_overlap: int,
@@ -84,19 +69,19 @@ def split_single_text_on_tokens(
 ) -> list[str]:
     """Split a single text and return chunks using the tokenizer."""
     result = []
-    input_ids = encode(text)
+    input_tokens = encode(text)
     start_idx = 0
-    cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
-    chunk_ids = input_ids[start_idx:cur_idx]
+    cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
+    chunk_tokens = input_tokens[start_idx:cur_idx]
 
-    while start_idx < len(input_ids):
-        chunk_text = decode(list(chunk_ids))
+    while start_idx < len(input_tokens):
+        chunk_text = decode(list(chunk_tokens))
         result.append(chunk_text)  # Append chunked text as string
-        if cur_idx == len(input_ids):
+        if cur_idx == len(input_tokens):
             break
         start_idx += tokens_per_chunk - chunk_overlap
-        cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
-        chunk_ids = input_ids[start_idx:cur_idx]
+        cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
+        chunk_tokens = input_tokens[start_idx:cur_idx]
 
     return result
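To make the windowing arithmetic concrete, here is a small sketch that calls the renamed split_text_on_tokens with ord/chr as a stand-in encoder/decoder (one token per character), mirroring the MockTokenizer used in the unit tests below. The keyword names (tokens_per_chunk, chunk_overlap, encode, decode) come from the call sites in this diff; everything else is illustrative.

from graphrag.chunking.token_text_splitter import split_text_on_tokens

text = "This is a test text, meaning to be taken seriously by this test only."
chunks = split_text_on_tokens(
    text,
    tokens_per_chunk=10,
    chunk_overlap=5,
    encode=lambda s: [ord(c) for c in s],       # one "token" per character
    decode=lambda ids: "".join(chr(i) for i in ids),
)
# Each chunk holds 10 "tokens" (characters here), and the window advances by
# tokens_per_chunk - chunk_overlap = 5 positions, so consecutive chunks share
# five characters of overlap.
print(chunks[:2])  # ['This is a ', 'is a test ']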

View File

@@ -10,7 +10,7 @@ from dataclasses import dataclass
 import numpy as np
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
+from graphrag.chunking.token_text_splitter import TokenTextSplitter
 from graphrag.index.utils.is_null import is_null
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.progress import ProgressTicker, progress_ticker

View File

@@ -5,10 +5,11 @@ from unittest import mock
 import pytest
 import tiktoken
 
-from graphrag.index.text_splitting.text_splitting import (
+from graphrag.chunking.token_text_splitter import (
     TokenTextSplitter,
-    split_single_text_on_tokens,
+    split_text_on_tokens,
 )
+from graphrag.tokenizer.get_tokenizer import get_tokenizer
 
 
 class MockTokenizer:
@@ -19,31 +20,34 @@ class MockTokenizer:
         return "".join(chr(id) for id in token_ids)
 
 
+tokenizer = get_tokenizer()
+
+
 def test_split_text_str_empty():
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
+    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
     result = splitter.split_text("")
     assert result == []
 
 
 def test_split_text_str_bool():
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
+    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
     result = splitter.split_text(None)  # type: ignore
     assert result == []
 
 
 def test_split_text_str_int():
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
+    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
     with pytest.raises(TypeError):
         splitter.split_text(123)  # type: ignore
 
 
-@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
+@mock.patch("graphrag.chunking.token_text_splitter.split_text_on_tokens")
 def test_split_text_large_input(mock_split):
     large_text = "a" * 10_000
     mock_split.return_value = ["chunk"] * 2_000
 
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
+    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
 
     result = splitter.split_text(large_text)
@@ -51,20 +55,7 @@ def test_split_text_large_input(mock_split):
     mock_split.assert_called_once()
 
 
-@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
-def test_token_text_splitter(mock_split_text):
-    expected_chunks = ["chunk1", "chunk2", "chunk3"]
-    mock_split_text.return_value = expected_chunks
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
-
-    splitter.split_text(["chunk1", "chunk2", "chunk3"])
-    mock_split_text.assert_called_once()
-
-
-def test_split_single_text_on_tokens():
+def test_split_text_on_tokens():
     text = "This is a test text, meaning to be taken seriously by this test only."
     mocked_tokenizer = MockTokenizer()
     expected_splits = [
@@ -83,7 +74,7 @@ def test_split_single_text_on_tokens():
         "est only.",
     ]
 
-    result = split_single_text_on_tokens(
+    result = split_text_on_tokens(
         text=text,
         chunk_overlap=5,
         tokens_per_chunk=10,
@@ -93,7 +84,7 @@ def test_split_single_text_on_tokens():
     assert result == expected_splits
 
 
-def test_split_single_text_on_tokens_no_overlap():
+def test_split_text_on_tokens_no_overlap():
     text = "This is a test text, meaning to be taken seriously by this test only."
 
     enc = tiktoken.get_encoding("cl100k_base")
@@ -123,7 +114,7 @@ def test_split_single_text_on_tokens_no_overlap():
         " only.",
     ]
 
-    result = split_single_text_on_tokens(
+    result = split_text_on_tokens(
         text=text,
         chunk_overlap=1,
         tokens_per_chunk=2,
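As the updated tests show, TokenTextSplitter no longer falls back to a default tokenizer, so callers construct one explicitly. A brief sketch of that pattern, assuming get_tokenizer() with no arguments returns a usable default as it does at module scope in these tests; the chunk sizes below are illustrative, not the defaults from the diff.

from graphrag.chunking.token_text_splitter import TokenTextSplitter
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# tokenizer is now a required constructor argument
splitter = TokenTextSplitter(
    tokenizer=get_tokenizer(),
    chunk_size=1200,
    chunk_overlap=100,
)
chunks = splitter.split_text("some long document text ...")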

View File

@@ -1,2 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License

View File

@@ -1,2 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License