Add ChunkResult model

Nathan Evans 2025-12-22 14:18:21 -08:00
parent ee20153d8c
commit a741bfb8d7
7 changed files with 111 additions and 18 deletions

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The ChunkResult dataclass."""
from dataclasses import dataclass
@dataclass
class ChunkResult:
"""Result of chunking a document."""
text: str
index: int
start_char: int
end_char: int
token_count: int | None = None
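
For orientation: start_char and end_char are inclusive, 0-based offsets into the source document, so a chunk can be sliced back out of the original text. A minimal sketch (the sample string is hypothetical, not from this commit):

from graphrag_chunking.chunk_result import ChunkResult

source = "First sentence. Second sentence."
chunk = ChunkResult(text="First sentence.", index=0, start_char=0, end_char=14)

# end_char is inclusive, so slicing back out needs end_char + 1
assert source[chunk.start_char : chunk.end_char + 1] == chunk.text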

View File

@@ -6,6 +6,8 @@
from abc import ABC, abstractmethod
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult


class Chunker(ABC):
    """Abstract base class for document chunkers."""
@@ -15,5 +17,5 @@ class Chunker(ABC):
        """Create a chunker instance."""

    @abstractmethod
    def chunk(self, text: str) -> list[str]:
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk method definition."""

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'create_chunk_results' function."""
from collections.abc import Callable
from graphrag_chunking.chunk_result import ChunkResult
def create_chunk_results(
chunks: list[str],
encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
"""Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks we not stripped relative to the source text."""
results = []
start_char = 0
for index, chunk in enumerate(chunks):
end_char = start_char + len(chunk) - 1 # 0-based indices
chunk = ChunkResult(
text=chunk,
index=index,
start_char=start_char,
end_char=end_char,
)
if encode:
chunk.token_count = len(encode(chunk.text))
results.append(chunk)
start_char = end_char + 1
return results
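
A quick trace of the offset bookkeeping on two contiguous chunks (the sample strings and the toy byte-level encoder are hypothetical):

from graphrag_chunking.create_chunk_results import create_chunk_results

results = create_chunk_results(["abcde", "fgh"], encode=lambda s: list(s.encode()))

assert (results[0].start_char, results[0].end_char) == (0, 4)  # "abcde"
assert (results[1].start_char, results[1].end_char) == (5, 7)  # "fgh"
assert results[0].token_count == 5  # one "token" per byte with the toy encoder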

View File

@@ -3,21 +3,42 @@
"""A module containing 'SentenceChunker' class."""

from collections.abc import Callable
from typing import Any

import nltk

from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results


class SentenceChunker(Chunker):
    """A chunker that splits text into sentence-based chunks."""

    def __init__(self, **kwargs: Any) -> None:
    def __init__(
        self, encode: Callable[[str], list[int]] | None = None, **kwargs: Any
    ) -> None:
        """Create a sentence chunker instance."""
        self._encode = encode
        bootstrap()

    def chunk(self, text) -> list[str]:
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into sentence-based chunks."""
        return nltk.sent_tokenize(text)
        sentences = nltk.sent_tokenize(text.strip())
        results = create_chunk_results(sentences, encode=self._encode)
        # nltk's sentence tokenizer drops inter-sentence whitespace, so
        # realign each chunk's start/end chars against the original text
        for index, result in enumerate(results):
            txt = result.text
            start = result.start_char
            actual_start = text.find(txt, start)
            delta = actual_start - start
            if delta > 0:
                result.start_char += delta
                result.end_char += delta
                # carry the shift into the next chunk so its find() starts
                # searching from an offset close to the true position
                if index < len(results) - 1:
                    results[index + 1].start_char += delta
                    results[index + 1].end_char += delta
        return results
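
The realignment loop is needed because create_chunk_results assumes contiguous chunks, while sent_tokenize returns sentences without their separating whitespace. A rough illustration, assuming the nltk data fetched by bootstrap() is available:

from graphrag_chunking.sentence_chunker import SentenceChunker

chunker = SentenceChunker()
chunks = chunker.chunk("  Hi there. Bye now.")

# offsets point into the original, un-stripped text
assert chunks[0].text == "Hi there."
assert (chunks[0].start_char, chunks[0].end_char) == (2, 10)
assert (chunks[1].start_char, chunks[1].end_char) == (12, 19)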

View File

@@ -6,7 +6,9 @@
from collections.abc import Callable
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results


class TokenChunker(Chunker):
@@ -26,15 +28,16 @@ class TokenChunker(Chunker):
        self._encode = encode
        self._decode = decode

    def chunk(self, text: str) -> list[str]:
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into token-based chunks."""
        return split_text_on_tokens(
        chunks = split_text_on_tokens(
            text,
            chunk_size=self._size,
            chunk_overlap=self._overlap,
            encode=self._encode,
            decode=self._decode,
        )
        return create_chunk_results(chunks, encode=self._encode)


def split_text_on_tokens(
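
Worth noting: token_count is recomputed by re-encoding each decoded chunk, and with a nonzero chunk_overlap the decoded chunks overlap in the source, so the contiguity assumption behind create_chunk_results no longer holds and character offsets drift. A rough usage sketch with tiktoken (constructor parameter names are guessed from the attributes above, not confirmed by this diff):

import tiktoken

from graphrag_chunking.token_chunker import TokenChunker

enc = tiktoken.get_encoding("cl100k_base")
chunker = TokenChunker(size=128, overlap=0, encode=enc.encode, decode=enc.decode)

document = " ".join(["Some filler text."] * 200)
for chunk in chunker.chunk(document):
    print(chunk.index, chunk.token_count, chunk.start_char, chunk.end_char)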

View File

@@ -66,7 +66,7 @@ def create_base_text_units(
    logger.info("Starting chunking process for %d documents", total_rows)

    def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
        row["chunks"] = chunker.chunk(row["text"])
        row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]

        metadata = row.get("metadata", None)
        if prepend_metadata and metadata is not None:

View File

@@ -28,22 +28,42 @@ class TestRunSentences:
    def test_basic_functionality(self):
        """Test basic sentence splitting without metadata"""
        input = "This is a test. Another sentence."
        input = "This is a test. Another sentence. And a third one!"
        chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
        chunks = chunker.chunk(input)
        assert len(chunks) == 3
        assert chunks[0].text == "This is a test."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 0
        assert chunks[0].end_char == 14
        assert chunks[1].text == "Another sentence."
        assert chunks[1].index == 1
        assert chunks[1].start_char == 16
        assert chunks[1].end_char == 32
        assert chunks[2].text == "And a third one!"
        assert chunks[2].index == 2
        assert chunks[2].start_char == 34
        assert chunks[2].end_char == 49

    def test_mixed_whitespace_handling(self):
        """Test input with irregular whitespace"""
        input = "   Sentence with spaces. Another one!   "
        chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
        chunks = chunker.chunk(input)
        assert len(chunks) == 2
        assert chunks[0] == "This is a test."
        assert chunks[1] == "Another sentence."
        assert chunks[0].text == "Sentence with spaces."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 3
        assert chunks[0].end_char == 23
    def test_mixed_whitespace_handling(self):
        """Test input with irregular whitespace"""
        input = "   Sentence with spaces. Another one!   "
        chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
        chunks = chunker.chunk(input)
        assert chunks[0] == "   Sentence with spaces."
        assert chunks[1] == "Another one!"
        assert chunks[1].text == "Another one!"
        assert chunks[1].index == 1
        assert chunks[1].start_char == 25
        assert chunks[1].end_char == 36

class TestRunTokens: