Rename ChunkResult to TextChunk and add transformer support

- Rename chunk_result.py to text_chunk.py with ChunkResult -> TextChunk
- Add 'original' field to TextChunk to track pre-transform text
- Add optional transform callback to chunker.chunk() method
- Add add_metadata transformer for prepending metadata to chunks
- Update create_chunk_results to apply transforms and populate original
- Update sentence_chunker and token_chunker with transform support
- Refactor create_base_text_units to use new transformer pattern
- Rename pluck_metadata to get/collect methods on TextDocument
Nathan Evans 2026-01-07 17:43:34 -08:00
parent e8e316f291
commit 39125b2b13
11 changed files with 122 additions and 84 deletions
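
The net effect, end to end: a chunker now accepts an optional str -> str callback, applies it to each chunk, and keeps both versions on the returned TextChunk. A minimal sketch of the new flow using the pieces from the diffs below (it assumes SentenceChunker can be constructed with no arguments, which the diffs do not confirm):

from graphrag_chunking.sentence_chunker import SentenceChunker
from graphrag_chunking.transformers import add_metadata

chunker = SentenceChunker()  # assumption: all constructor args are optional
chunks = chunker.chunk(
    "The mission began. Everyone was ready.",
    transform=add_metadata({"title": "Operation Dulce"}),
)

chunks[0].original  # "The mission began."
chunks[0].text      # "title: Operation Dulce\nThe mission began."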

graphrag_chunking/add_metadata.py (deleted)

@ -1,19 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'add_metadata' function."""
def add_metadata(
text: str,
metadata: dict,
delimiter: str = ": ",
line_delimiter: str = "\n",
append: bool = False,
) -> str:
"""Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
metadata_str = (
line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
+ line_delimiter
)
return text + metadata_str if append else metadata_str + text
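
For reference, the removed helper took the text as its first argument and was called once per chunk (the sample strings below are hypothetical):

# Old form, per the signature above: add_metadata(text, metadata, ...)
result = add_metadata("The mission began.", {"title": "Operation Dulce"})
# -> "title: Operation Dulce\nThe mission began."

# append=True placed the metadata block after the text instead:
result = add_metadata("The mission began.", {"title": "Operation Dulce"}, append=True)
# -> "The mission began.title: Operation Dulce\n"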

graphrag_chunking/chunk_result.py (deleted)

@ -1,17 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The ChunkResult dataclass."""
from dataclasses import dataclass
@dataclass
class ChunkResult:
"""Result of chunking a document."""
text: str
index: int
start_char: int
end_char: int
token_count: int | None = None

graphrag_chunking/chunker.py

@ -4,9 +4,10 @@
"""A module containing the 'Chunker' class."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.text_chunk import TextChunk
class Chunker(ABC):
@ -17,5 +18,7 @@ class Chunker(ABC):
"""Create a chunker instance."""
@abstractmethod
def chunk(self, text: str) -> list[ChunkResult]:
def chunk(
self, text: str, transform: Callable[[str], str] | None = None
) -> list[TextChunk]:
"""Chunk method definition."""

graphrag_chunking/create_chunk_results.py

@ -5,26 +5,28 @@
from collections.abc import Callable
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.text_chunk import TextChunk
def create_chunk_results(
chunks: list[str],
transform: Callable[[str], str] | None = None,
encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
"""Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks we not stripped relative to the source text."""
) -> list[TextChunk]:
"""Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks were not stripped relative to the source text."""
results = []
start_char = 0
for index, chunk in enumerate(chunks):
end_char = start_char + len(chunk) - 1 # 0-based indices
chunk = ChunkResult(
text=chunk,
result = TextChunk(
original=chunk,
text=transform(chunk) if transform else chunk,
index=index,
start_char=start_char,
end_char=end_char,
)
if encode:
chunk.token_count = len(encode(chunk.text))
results.append(chunk)
result.token_count = len(encode(result.text))
results.append(result)
start_char = end_char + 1
return results
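
The bookkeeping is easy to sanity-check with a whitespace-preserving split and a toy encoder (both are illustrative stand-ins; a real encode returns token ids, but only the length of its result is used here):

from graphrag_chunking.create_chunk_results import create_chunk_results

chunks = ["Hello world. ", "Goodbye."]  # concatenating reproduces the source text

results = create_chunk_results(
    chunks,
    transform=str.upper,         # applied to text; original keeps the raw chunk
    encode=lambda s: s.split(),  # toy encoder: one "token" per word
)

assert results[0].original == "Hello world. "
assert results[0].text == "HELLO WORLD. "
assert (results[0].start_char, results[0].end_char) == (0, 12)  # inclusive
assert (results[1].start_char, results[1].end_char) == (13, 20)
assert results[1].token_count == 1  # counted on the transformed text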

graphrag_chunking/sentence_chunker.py

@ -9,9 +9,9 @@ from typing import Any
import nltk
from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
from graphrag_chunking.text_chunk import TextChunk
class SentenceChunker(Chunker):
@ -24,10 +24,14 @@ class SentenceChunker(Chunker):
self._encode = encode
bootstrap()
def chunk(self, text) -> list[ChunkResult]:
def chunk(
self, text: str, transform: Callable[[str], str] | None = None
) -> list[TextChunk]:
"""Chunk the text into sentence-based chunks."""
sentences = nltk.sent_tokenize(text.strip())
results = create_chunk_results(sentences, encode=self._encode)
results = create_chunk_results(
sentences, transform=transform, encode=self._encode
)
# nltk sentence tokenizer may trim whitespace, so we need to adjust start/end chars
for index, result in enumerate(results):
txt = result.text
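
The adjustment loop exists because nltk.sent_tokenize drops the whitespace between sentences, so the offsets assigned by create_chunk_results have to be shifted back into source coordinates. A sketch of the expected result (again assuming the constructor's arguments are all optional):

from graphrag_chunking.sentence_chunker import SentenceChunker

first, second = SentenceChunker().chunk("One sentence.  Another one.")

(first.start_char, first.end_char)    # (0, 12)
(second.start_char, second.end_char)  # (15, 26): the two spaces are skipped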

graphrag_chunking/text_chunk.py (new)

@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The TextChunk dataclass."""
from dataclasses import dataclass
@dataclass
class TextChunk:
"""Result of chunking a document."""
original: str
"""Raw original text chunk before any transformation."""
text: str
"""The final text content of this chunk."""
index: int
"""Zero-based index of this chunk within the source document."""
start_char: int
"""Character index where the raw chunk text begins in the source document."""
end_char: int
"""Character index where the raw chunk text ends in the source document."""
token_count: int | None = None
"""Number of tokens in the final chunk text, if computed."""

graphrag_chunking/token_chunker.py

@ -6,9 +6,9 @@
from collections.abc import Callable
from typing import Any
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
from graphrag_chunking.text_chunk import TextChunk
class TokenChunker(Chunker):
@ -28,7 +28,9 @@ class TokenChunker(Chunker):
self._encode = encode
self._decode = decode
def chunk(self, text: str) -> list[ChunkResult]:
def chunk(
self, text: str, transform: Callable[[str], str] | None = None
) -> list[TextChunk]:
"""Chunk the text into token-based chunks."""
chunks = split_text_on_tokens(
text,
@ -37,7 +39,7 @@ class TokenChunker(Chunker):
encode=self._encode,
decode=self._decode,
)
return create_chunk_results(chunks, encode=self._encode)
return create_chunk_results(chunks, transform=transform, encode=self._encode)
def split_text_on_tokens(
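
Usage mirrors SentenceChunker, with token_count populated from the encoder. A sketch with tiktoken; the chunk-size and overlap parameter names below are assumptions, since those constructor arguments are elided from this diff:

import tiktoken

from graphrag_chunking.token_chunker import TokenChunker
from graphrag_chunking.transformers import add_metadata

enc = tiktoken.get_encoding("cl100k_base")
chunker = TokenChunker(
    tokens_per_chunk=300,  # assumed parameter name
    overlap=50,            # assumed parameter name
    encode=enc.encode,
    decode=enc.decode,
)

results = chunker.chunk("word " * 2000, transform=add_metadata({"title": "Doc"}))
# token_count reflects the transformed text, metadata header included:
assert all(c.token_count == len(enc.encode(c.text)) for c in results)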

graphrag_chunking/transformers.py (new)

@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A collection of useful built-in transformers you can use for chunking."""
from collections.abc import Callable
from typing import Any
def add_metadata(
metadata: dict[str, Any],
delimiter: str = ": ",
line_delimiter: str = "\n",
append: bool = False,
) -> Callable[[str], str]:
"""Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
def transformer(text: str) -> str:
metadata_str = (
line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
+ line_delimiter
)
return text + metadata_str if append else metadata_str + text
return transformer
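
Because add_metadata now returns a plain str -> str callable, it drops straight into the transform parameter, and custom transformers are just functions that compose. For instance, pairing it with a hypothetical whitespace normalizer:

from graphrag_chunking.transformers import add_metadata

prepend_title = add_metadata({"title": "Operation Dulce"})

def normalize_whitespace(text: str) -> str:
    """Hypothetical second transformer: collapse whitespace runs."""
    return " ".join(text.split())

def combined(text: str) -> str:
    # Ordinary function composition: normalize first, then prepend metadata.
    return prepend_title(normalize_whitespace(text))

combined("The   mission\tbegan.")
# -> "title: Operation Dulce\nThe mission began."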

TextDocument (in graphrag_input)

@ -27,27 +27,33 @@ class TextDocument:
raw_data: dict[str, Any] | None = None
"""Raw data from source document."""
def pluck_metadata(self, fields: list[str]) -> dict[str, Any]:
"""Extract metadata fields from a TextDocument.
def get(self, field: str, default_value: Any = None) -> Any:
"""
Get a single field from the TextDocument.
Functions like the get method on a dictionary, returning default_value if the field is not found.
Supports nested fields using dot notation.
This takes a two-step approach for flexibility:
1. If the field is one of the standard text document fields (id, title, text, creation_date), just grab it directly. This accommodates plain unstructured text, for example, which has only the standard fields.
2. Otherwise, try to extract it from the raw_data dict. This allows users to specify any column from the original input file.
If a field is not found in either location, default_value is returned.
"""
metadata = {}
if field in ["id", "title", "text", "creation_date"]:
return getattr(self, field)
raw = self.raw_data or {}
try:
return get_property(raw, field)
except KeyError:
return default_value
def collect(self, fields: list[str]) -> dict[str, Any]:
"""Extract data fields from a TextDocument into a dict."""
data = {}
for field in fields:
if field in ["id", "title", "text", "creation_date"]:
value = getattr(self, field)
else:
raw = self.raw_data or {}
value = get_property(raw, field)
if value is None:
logger.warning(
"Metadata field '%s' not found in TextDocument standard fields or raw_data. Please check your configuration.",
field,
)
value = self.get(field)
if value is not None:
metadata[field] = value
return metadata
data[field] = value
return data
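
The split of responsibilities: get probes a single field (standard attribute first, then raw_data, with dot notation handled by get_property), while collect builds the metadata dict that the chunking workflow feeds to add_metadata. A sketch, assuming TextDocument is constructible directly with the fields shown and that creation_date is a string:

from graphrag_input import TextDocument

doc = TextDocument(
    id="1",
    title="Operation Dulce",
    text="The mission began.",
    creation_date="2024-01-01",
    raw_data={"source": {"name": "archive"}},
)

doc.get("title")           # "Operation Dulce" (standard field)
doc.get("source.name")     # "archive" (dot notation into raw_data)
doc.get("missing", "n/a")  # "n/a" (default_value instead of raising)

doc.collect(["title", "source.name"])
# -> {"title": "Operation Dulce", "source.name": "archive"}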

create_base_text_units (graphrag workflow)

@ -7,9 +7,9 @@ import logging
from typing import Any, cast
import pandas as pd
from graphrag_chunking.add_metadata import add_metadata
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.transformers import add_metadata
from graphrag_input import TextDocument
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -66,8 +66,6 @@ def create_base_text_units(
logger.info("Starting chunking process for %d documents", total_rows)
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
if prepend_metadata:
# create a standard text document for metadata collection
# ignore any additional fields in case the input dataframe has extra columns
@ -76,13 +74,17 @@ def create_base_text_units(
title=row["title"],
text=row["text"],
creation_date=row["creation_date"],
raw_data=row.get("raw_data", None),
raw_data=row["raw_data"],
)
metadata = document.pluck_metadata(prepend_metadata)
row["chunks"] = [
add_metadata(chunk, metadata, line_delimiter=".\n")
for chunk in row["chunks"]
]
metadata = document.collect(prepend_metadata)
transformer = add_metadata(metadata=metadata, line_delimiter=".\n") # delim with . for back-compat older indexes
else:
transformer = None
row["chunks"] = [
chunk.text for chunk in chunker.chunk(row["text"], transform=transformer)
]
tick()
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return row
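
The line_delimiter=".\n" override means every metadata row is terminated with a period, matching what older indexes stored. With hypothetical values:

from graphrag_chunking.transformers import add_metadata

transformer = add_metadata(
    metadata={"title": "Operation Dulce", "year": 2024},
    line_delimiter=".\n",
)
transformer("The mission began.")
# -> "title: Operation Dulce.\nyear: 2024.\nThe mission began."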

tests for graphrag_chunking.transformers

@ -1,14 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag_chunking.add_metadata import add_metadata
from graphrag_chunking.transformers import add_metadata
def test_add_metadata_one_row():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello"}
results = [add_metadata(chunk, metadata) for chunk in chunks]
transformer = add_metadata(metadata)
results = [transformer(chunk) for chunk in chunks]
assert results[0] == "message: hello\nThis is a test."
assert results[1] == "message: hello\nAnother sentence."
@ -17,7 +18,8 @@ def test_add_metadata_one_row_append():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello"}
results = [add_metadata(chunk, metadata, append=True) for chunk in chunks]
transformer = add_metadata(metadata, append=True)
results = [transformer(chunk) for chunk in chunks]
assert results[0] == "This is a test.message: hello\n"
assert results[1] == "Another sentence.message: hello\n"
@ -26,7 +28,8 @@ def test_add_metadata_multiple_rows():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello", "tag": "first"}
results = [add_metadata(chunk, metadata) for chunk in chunks]
transformer = add_metadata(metadata)
results = [transformer(chunk) for chunk in chunks]
assert results[0] == "message: hello\ntag: first\nThis is a test."
assert results[1] == "message: hello\ntag: first\nAnother sentence."
@ -35,9 +38,7 @@ def test_add_metadata_custom_delimiters():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello", "tag": "first"}
results = [
add_metadata(chunk, metadata, delimiter="-", line_delimiter="_")
for chunk in chunks
]
transformer = add_metadata(metadata, delimiter="-", line_delimiter="_")
results = [transformer(chunk) for chunk in chunks]
assert results[0] == "message-hello_tag-first_This is a test."
assert results[1] == "message-hello_tag-first_Another sentence."