Add prepending tests

Nathan Evans 2025-12-19 16:36:30 -08:00
parent eb22d7a61b
commit 780a03827c
5 changed files with 42 additions and 66 deletions

View File

@@ -15,5 +15,5 @@ class Chunker(ABC):
"""Create a chunker instance."""
@abstractmethod
def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
"""Chunk method definition."""

View File

@@ -3,7 +3,6 @@
"""A module containing 'SentenceChunker' class."""
import json
from typing import Any
import nltk
@@ -20,19 +19,11 @@ class SentenceChunker(Chunker):
self._prepend_metadata = prepend_metadata
bootstrap()
def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
"""Chunk the text into sentence-based chunks."""
chunks = nltk.sent_tokenize(text)
if self._prepend_metadata and metadata is not None:
line_delimiter = ".\n"
metadata_str = ""
if isinstance(metadata, str):
metadata = json.loads(metadata)
if isinstance(metadata, dict):
metadata_str = (
line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
+ line_delimiter
)
metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
chunks = [metadata_str + chunk for chunk in chunks]
return chunks
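With the json.loads branch gone, the prepend step is just a key/value join. A worked example of the prefix it builds (metadata values are illustrative):

metadata = {"title": "Report", "year": 2024}
prefix = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
# prefix == "title: Report.\nyear: 2024.\n"
# each sentence chunk becomes prefix + sentence, e.g.
# "title: Report.\nyear: 2024.\nThis is a test."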

View File

@@ -3,7 +3,6 @@
"""A module containing 'TokenChunker' class."""
import json
from collections.abc import Callable
from typing import Any
@@ -23,7 +22,6 @@ class TokenChunker(Chunker):
self,
size: int,
overlap: int,
encoding_model: str,
tokenizer: Tokenizer,
prepend_metadata: bool = False,
chunk_size_includes_metadata: bool = False,
@@ -32,25 +30,18 @@ class TokenChunker(Chunker):
"""Create a token chunker instance."""
self._size = size
self._overlap = overlap
self._encoding_model = encoding_model
self._prepend_metadata = prepend_metadata
self._chunk_size_includes_metadata = chunk_size_includes_metadata
self._tokenizer = tokenizer
def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
"""Chunk the text into token-based chunks."""
line_delimiter = ".\n"
# we have to create and measure the metadata first to account for the length when chunking
metadata_str = ""
metadata_tokens = 0
if self._prepend_metadata and metadata is not None:
if isinstance(metadata, str):
metadata = json.loads(metadata)
if isinstance(metadata, dict):
metadata_str = (
line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
+ line_delimiter
)
metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
if self._chunk_size_includes_metadata:
metadata_tokens = len(self._tokenizer.encode(metadata_str))
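The prefix is measured before chunking so its length can be charged against the configured size when chunk_size_includes_metadata is set. A rough sketch of that accounting with a toy character-level encoder (the subtraction is an assumption about how metadata_tokens is used; the rest of the method is outside this diff):

encode = lambda s: [ord(c) for c in s]       # stand-in for Tokenizer.encode

size = 20                                    # configured chunk size
metadata_str = "title: Report.\n"
metadata_tokens = len(encode(metadata_str))  # 15 "tokens" under the toy encoder
text_budget = size - metadata_tokens         # assumed: 5 tokens left for the text itself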

View File

@@ -3,6 +3,7 @@
"""A module containing run_workflow method definition."""
import json
import logging
from typing import Any, cast
@@ -62,7 +63,10 @@ def create_base_text_units(
logger.info("Starting chunking process for %d documents", total_rows)
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
row["chunks"] = chunker.chunk(row["text"], row.get("metadata"))
metadata = row.get("metadata")
if (metadata is not None) and isinstance(metadata, str):
metadata = json.loads(metadata)
row["chunks"] = chunker.chunk(row["text"], metadata=metadata)
tick()
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return row
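JSON parsing now happens once at the workflow boundary instead of inside each chunker. A minimal sketch of that step in isolation (parse_row_metadata is a hypothetical helper; the row values are illustrative):

import json

def parse_row_metadata(row: dict) -> dict | None:
    """Hypothetical helper mirroring the new workflow step."""
    metadata = row.get("metadata")
    if metadata is not None and isinstance(metadata, str):
        metadata = json.loads(metadata)
    return metadata

row = {"text": "This is a test.", "metadata": '{"message": "hello"}'}
print(parse_row_metadata(row))               # {'message': 'hello'} -- chunkers only ever see a dict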

View File

@@ -11,14 +11,15 @@ from graphrag.chunking.token_chunker import (
split_text_on_tokens,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag_common.types.tokenizer import Tokenizer
class MockTokenizer:
def encode(self, text):
class MockTokenizer(Tokenizer):
def encode(self, text) -> list[int]:
return [ord(char) for char in text]
def decode(self, token_ids):
return "".join(chr(id) for id in token_ids)
def decode(self, tokens) -> str:
return "".join(chr(id) for id in tokens)
tokenizer = get_tokenizer()
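MockTokenizer now satisfies the shared Tokenizer type, so it can stand in wherever a real tokenizer is expected. Its encode/decode round trip, shown inline for reference:

ids = [ord(c) for c in "hi"]                 # MockTokenizer.encode: one code point per character
assert "".join(chr(i) for i in ids) == "hi"  # MockTokenizer.decode inverts it exactly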
@@ -53,6 +54,18 @@ class TestRunSentences:
assert chunks[0] == " Sentence with spaces."
assert chunks[1] == "Another one!"
def test_prepend_metadata(self):
"""Test prepending metadata to chunks"""
input = "This is a test. Another sentence."
config = ChunkingConfig(
strategy=ChunkStrategyType.Sentence, prepend_metadata=True
)
chunker = create_chunker(config)
chunks = chunker.chunk(input, metadata={"message": "hello"})
assert chunks[0] == "message: hello.\nThis is a test."
assert chunks[1] == "message: hello.\nAnother sentence."
class TestRunTokens:
@patch("tiktoken.get_encoding")
@@ -75,6 +88,20 @@ class TestRunTokens:
assert len(chunks) > 0
def test_prepend_metadata(self):
"""Test prepending metadata to chunks"""
mocked_tokenizer = MockTokenizer()
input = "This is a test."
config = ChunkingConfig(
strategy=ChunkStrategyType.Tokens, size=5, overlap=0, prepend_metadata=True
)
chunker = create_chunker(config, tokenizer=mocked_tokenizer)
chunks = chunker.chunk(input, metadata={"message": "hello"})
assert chunks[0] == "message: hello.\nThis "
assert chunks[1] == "message: hello.\nis a "
assert chunks[2] == "message: hello.\ntest."
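A quick check of the expected chunks above, reproducing the character-per-token scheme of MockTokenizer inline (the fixed-step split mirrors size=5, overlap=0; this is a sketch, not the library's splitter):

text = "This is a test."
tokens = [ord(c) for c in text]              # 15 one-character "tokens"
size = 5
chunks = ["".join(chr(t) for t in tokens[i:i + size]) for i in range(0, len(tokens), size)]
# chunks == ['This ', 'is a ', 'test.']
# chunk_size_includes_metadata defaults to False, so "message: hello.\n" is prepended
# after splitting and does not count against size=5.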
def test_split_text_str_empty():
result = split_text_on_tokens(
@@ -146,40 +173,3 @@ def test_split_text_on_tokens_no_overlap():
encode=tok.encode,
)
assert result == expected_splits
@patch("tiktoken.get_encoding")
def test_get_encoding_fn_encode(mock_get_encoding):
# Create a mock encoding object with encode and decode methods
mock_encoding = Mock()
mock_encoding.encode = Mock(return_value=[1, 2, 3])
mock_encoding.decode = Mock(return_value="decoded text")
# Configure the mock_get_encoding to return the mock encoding object
mock_get_encoding.return_value = mock_encoding
# Call the function to get encode and decode functions
tokenizer = get_tokenizer(encoding_model="mock_encoding")
# Test the encode function
encoded_text = tokenizer.encode("test text")
assert encoded_text == [1, 2, 3]
mock_encoding.encode.assert_called_once_with("test text")
@patch("tiktoken.get_encoding")
def test_get_encoding_fn_decode(mock_get_encoding):
# Create a mock encoding object with encode and decode methods
mock_encoding = Mock()
mock_encoding.encode = Mock(return_value=[1, 2, 3])
mock_encoding.decode = Mock(return_value="decoded text")
# Configure the mock_get_encoding to return the mock encoding object
mock_get_encoding.return_value = mock_encoding
# Call the function to get encode and decode functions
tokenizer = get_tokenizer(encoding_model="mock_encoding")
decoded_text = tokenizer.decode([1, 2, 3])
assert decoded_text == "decoded text"
mock_encoding.decode.assert_called_once_with([1, 2, 3])