mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Rename ChunkResult to TextChunk and add transformer support
- Rename chunk_result.py to text_chunk.py with ChunkResult -> TextChunk
- Add 'original' field to TextChunk to track pre-transform text
- Add optional transform callback to chunker.chunk() method
- Add add_metadata transformer for prepending metadata to chunks
- Update create_chunk_results to apply transforms and populate original
- Update sentence_chunker and token_chunker with transform support
- Refactor create_base_text_units to use new transformer pattern
- Rename pluck_metadata to get/collect methods on TextDocument
This commit is contained in:
parent e8e316f291
commit 39125b2b13
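Taken together, the changes work like this: chunk() now accepts an optional transform callback that is applied to each chunk's text after splitting, while the new original field on TextChunk preserves the pre-transform text and keeps the character offsets meaningful. A minimal sketch of the resulting API, using only signatures that appear in this diff (the sample strings are illustrative, and the graphrag_chunking package is assumed to be importable):

    from graphrag_chunking.create_chunk_results import create_chunk_results
    from graphrag_chunking.transformers import add_metadata

    # Build a reusable Callable[[str], str] that prepends metadata rows.
    transform = add_metadata({"title": "Moby Dick"})

    # Materialize TextChunk records, applying the transform per chunk.
    chunks = create_chunk_results(["Call me Ishmael.", " Some years ago."], transform=transform)

    print(chunks[0].text)      # "title: Moby Dick\nCall me Ishmael."
    print(chunks[0].original)  # "Call me Ishmael." (pre-transform text)
    print(chunks[0].start_char, chunks[0].end_char)  # 0 15 (offsets track the original)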
@@ -1,19 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'add_metadata' function."""
-
-
-def add_metadata(
-    text: str,
-    metadata: dict,
-    delimiter: str = ": ",
-    line_delimiter: str = "\n",
-    append: bool = False,
-) -> str:
-    """Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
-    metadata_str = (
-        line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
-        + line_delimiter
-    )
-    return text + metadata_str if append else metadata_str + text
@@ -1,17 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""The ChunkResult dataclass."""
-
-from dataclasses import dataclass
-
-
-@dataclass
-class ChunkResult:
-    """Result of chunking a document."""
-
-    text: str
-    index: int
-    start_char: int
-    end_char: int
-    token_count: int | None = None
@@ -4,9 +4,10 @@
 """A module containing the 'Chunker' class."""
 
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from typing import Any
 
-from graphrag_chunking.chunk_result import ChunkResult
+from graphrag_chunking.text_chunk import TextChunk
 
 
 class Chunker(ABC):
@@ -17,5 +18,7 @@ class Chunker(ABC):
         """Create a chunker instance."""
 
     @abstractmethod
-    def chunk(self, text: str) -> list[ChunkResult]:
+    def chunk(
+        self, text: str, transform: Callable[[str], str] | None = None
+    ) -> list[TextChunk]:
         """Chunk method definition."""
@@ -5,26 +5,28 @@
 
 from collections.abc import Callable
 
-from graphrag_chunking.chunk_result import ChunkResult
+from graphrag_chunking.text_chunk import TextChunk
 
 
 def create_chunk_results(
     chunks: list[str],
+    transform: Callable[[str], str] | None = None,
     encode: Callable[[str], list[int]] | None = None,
-) -> list[ChunkResult]:
-    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks we not stripped relative to the source text."""
+) -> list[TextChunk]:
+    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks were not stripped relative to the source text."""
     results = []
     start_char = 0
     for index, chunk in enumerate(chunks):
         end_char = start_char + len(chunk) - 1  # 0-based indices
-        chunk = ChunkResult(
-            text=chunk,
+        result = TextChunk(
+            original=chunk,
+            text=transform(chunk) if transform else chunk,
             index=index,
             start_char=start_char,
             end_char=end_char,
         )
         if encode:
-            chunk.token_count = len(encode(chunk.text))
-        results.append(chunk)
+            result.token_count = len(encode(result.text))
+        results.append(result)
        start_char = end_char + 1
     return results
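The offsets above are inclusive and 0-based: a chunk of length n occupies [start_char, start_char + n - 1], and the next chunk begins at end_char + 1, so the spans tile the source text exactly when the chunks were not stripped. The same bookkeeping in a self-contained form (char_spans is a hypothetical helper, not part of the package):

    def char_spans(chunks: list[str]) -> list[tuple[int, int]]:
        """Mirror the loop above: inclusive, 0-based (start_char, end_char) spans."""
        spans, start = [], 0
        for chunk in chunks:
            end = start + len(chunk) - 1
            spans.append((start, end))
            start = end + 1
        return spans

    assert char_spans(["Hello. ", "World!"]) == [(0, 6), (7, 12)]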
@@ -9,9 +9,9 @@ from typing import Any
 import nltk
 
 from graphrag_chunking.bootstrap_nltk import bootstrap
-from graphrag_chunking.chunk_result import ChunkResult
 from graphrag_chunking.chunker import Chunker
 from graphrag_chunking.create_chunk_results import create_chunk_results
+from graphrag_chunking.text_chunk import TextChunk
 
 
 class SentenceChunker(Chunker):
@@ -24,10 +24,14 @@ class SentenceChunker(Chunker):
         self._encode = encode
         bootstrap()
 
-    def chunk(self, text) -> list[ChunkResult]:
+    def chunk(
+        self, text: str, transform: Callable[[str], str] | None = None
+    ) -> list[TextChunk]:
         """Chunk the text into sentence-based chunks."""
         sentences = nltk.sent_tokenize(text.strip())
-        results = create_chunk_results(sentences, encode=self._encode)
+        results = create_chunk_results(
+            sentences, transform=transform, encode=self._encode
+        )
         # nltk sentence tokenizer may trim whitespace, so we need to adjust start/end chars
         for index, result in enumerate(results):
             txt = result.text
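The hunk is truncated here, but the context comment explains the intent: nltk.sent_tokenize drops the whitespace between sentences, so the uniform offsets from create_chunk_results drift and must be realigned against the source text, which is exactly the pre-transform text the new original field preserves. A hedged sketch of that kind of realignment (illustrative only; the actual adjustment code is not shown in this diff):

    def realign(text: str, sentences: list[str]) -> list[tuple[int, int]]:
        """Locate each tokenized sentence in the source to recover true offsets."""
        spans, cursor = [], 0
        for sentence in sentences:
            start = text.find(sentence, cursor)  # skip whitespace the tokenizer trimmed
            end = start + len(sentence) - 1
            spans.append((start, end))
            cursor = end + 1
        return spans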
packages/graphrag-chunking/graphrag_chunking/text_chunk.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""The TextChunk dataclass."""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class TextChunk:
+    """Result of chunking a document."""
+
+    original: str
+    """Raw original text chunk before any transformation."""
+
+    text: str
+    """The final text content of this chunk."""
+
+    index: int
+    """Zero-based index of this chunk within the source document."""
+
+    start_char: int
+    """Character index where the raw chunk text begins in the source document."""
+
+    end_char: int
+    """Character index where the raw chunk text ends in the source document."""
+
+    token_count: int | None = None
+    """Number of tokens in the final chunk text, if computed."""
@@ -6,9 +6,9 @@
 from collections.abc import Callable
 from typing import Any
 
-from graphrag_chunking.chunk_result import ChunkResult
 from graphrag_chunking.chunker import Chunker
 from graphrag_chunking.create_chunk_results import create_chunk_results
+from graphrag_chunking.text_chunk import TextChunk
 
 
 class TokenChunker(Chunker):
@@ -28,7 +28,9 @@ class TokenChunker(Chunker):
         self._encode = encode
         self._decode = decode
 
-    def chunk(self, text: str) -> list[ChunkResult]:
+    def chunk(
+        self, text: str, transform: Callable[[str], str] | None = None
+    ) -> list[TextChunk]:
         """Chunk the text into token-based chunks."""
         chunks = split_text_on_tokens(
             text,
@@ -37,7 +39,7 @@ class TokenChunker(Chunker):
             encode=self._encode,
             decode=self._decode,
         )
-        return create_chunk_results(chunks, encode=self._encode)
+        return create_chunk_results(chunks, transform=transform, encode=self._encode)
 
 
 def split_text_on_tokens(
packages/graphrag-chunking/graphrag_chunking/transformers.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A collection of useful built-in transformers you can use for chunking."""
+
+from collections.abc import Callable
+from typing import Any
+
+
+def add_metadata(
+    metadata: dict[str, Any],
+    delimiter: str = ": ",
+    line_delimiter: str = "\n",
+    append: bool = False,
+) -> Callable[[str], str]:
+    """Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
+
+    def transformer(text: str) -> str:
+        metadata_str = (
+            line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
+            + line_delimiter
+        )
+        return text + metadata_str if append else metadata_str + text
+
+    return transformer
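Because add_metadata now returns a plain Callable[[str], str] instead of transforming text directly, transformers compose naturally. A small illustrative helper (compose is hypothetical, not part of this diff; assumes add_metadata as defined above):

    def compose(*fns):
        """Chain Callable[[str], str] transformers, applied left to right."""
        def transformer(text: str) -> str:
            for fn in fns:
                text = fn(text)
            return text
        return transformer

    # Uppercase each chunk, then prepend metadata.
    pipeline = compose(str.upper, add_metadata({"title": "Moby Dick"}))
    assert pipeline("call me ishmael.") == "title: Moby Dick\nCALL ME ISHMAEL."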
@@ -27,27 +27,33 @@ class TextDocument:
     raw_data: dict[str, Any] | None = None
     """Raw data from source document."""
 
-    def pluck_metadata(self, fields: list[str]) -> dict[str, Any]:
-        """Extract metadata fields from a TextDocument.
+    def get(self, field: str, default_value: Any = None) -> Any:
+        """
+        Get a single field from the TextDocument.
+
+        Functions like the get method on a dictionary, returning default_value if the field is not found.
+
+        Supports nested fields using dot notation.
 
         This takes a two step approach for flexibility:
         1. If the field is one of the standard text document fields (id, title, text, creation_date), just grab it directly. This accommodates unstructured text for example, which just has the standard fields.
         2. Otherwise, try to extract it from the raw_data dict. This allows users to specify any column from the original input file.
 
         If a field does not exist in either location, we'll throw because that means the user config is incorrect.
         """
-        metadata = {}
+        if field in ["id", "title", "text", "creation_date"]:
+            return getattr(self, field)
+
+        raw = self.raw_data or {}
+        try:
+            return get_property(raw, field)
+        except KeyError:
+            return default_value
+
+    def collect(self, fields: list[str]) -> dict[str, Any]:
+        """Extract data fields from a TextDocument into a dict."""
+        data = {}
         for field in fields:
-            if field in ["id", "title", "text", "creation_date"]:
-                value = getattr(self, field)
-            else:
-                raw = self.raw_data or {}
-                value = get_property(raw, field)
-            if value is None:
-                logger.warning(
-                    "Metadata field '%s' not found in TextDocument standard fields or raw_data. Please check your configuration.",
-                    field,
-                )
+            value = self.get(field)
             if value is not None:
-                metadata[field] = value
-        return metadata
+                data[field] = value
+        return data
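In use, get resolves a single field and collect batches several into a dict keyed by the requested field names. A hedged sketch, assuming the TextDocument constructor accepts the standard fields plus raw_data and that get_property resolves dot-notation paths into nested dicts:

    doc = TextDocument(
        id="1",
        title="Moby Dick",
        text="Call me Ishmael.",
        creation_date="1851-10-18",
        raw_data={"source": {"publisher": "Bentley"}},
    )

    assert doc.get("title") == "Moby Dick"           # standard field, via getattr
    assert doc.get("source.publisher") == "Bentley"  # nested field, via get_property
    assert doc.get("missing", "n/a") == "n/a"        # KeyError swallowed -> default
    assert doc.collect(["title", "source.publisher"]) == {
        "title": "Moby Dick",
        "source.publisher": "Bentley",
    }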
@@ -7,9 +7,9 @@ import logging
 from typing import Any, cast
 
 import pandas as pd
-from graphrag_chunking.add_metadata import add_metadata
 from graphrag_chunking.chunker import Chunker
 from graphrag_chunking.chunker_factory import create_chunker
+from graphrag_chunking.transformers import add_metadata
 from graphrag_input import TextDocument
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -66,8 +66,6 @@ def create_base_text_units(
     logger.info("Starting chunking process for %d documents", total_rows)
 
     def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
-        row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
-
         if prepend_metadata:
             # create a standard text document for metadata plucking
             # ignore any additional fields in case the input dataframe has extra columns
@@ -76,13 +74,17 @@
                 title=row["title"],
                 text=row["text"],
                 creation_date=row["creation_date"],
-                raw_data=row.get("raw_data", None),
+                raw_data=row["raw_data"],
             )
-            metadata = document.pluck_metadata(prepend_metadata)
-            row["chunks"] = [
-                add_metadata(chunk, metadata, line_delimiter=".\n")
-                for chunk in row["chunks"]
-            ]
+            metadata = document.collect(prepend_metadata)
+            transformer = add_metadata(metadata=metadata, line_delimiter=".\n")  # delim with . for back-compat with older indexes
+        else:
+            transformer = None
+
+        row["chunks"] = [
+            chunk.text for chunk in chunker.chunk(row["text"], transform=transformer)
+        ]
 
         tick()
         logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
         return row
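The line_delimiter=".\n" choice terminates each metadata row with a period before the newline, reproducing what older indexes stored. Given add_metadata's formatting rules, the prepended output looks like this (metadata values are illustrative):

    transformer = add_metadata(metadata={"title": "Moby Dick", "year": 1851}, line_delimiter=".\n")
    assert transformer("Call me Ishmael.") == "title: Moby Dick.\nyear: 1851.\nCall me Ishmael."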
@@ -1,14 +1,15 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-from graphrag_chunking.add_metadata import add_metadata
+from graphrag_chunking.transformers import add_metadata
 
 
 def test_add_metadata_one_row():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello"}
-    results = [add_metadata(chunk, metadata) for chunk in chunks]
+    transformer = add_metadata(metadata)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message: hello\nThis is a test."
     assert results[1] == "message: hello\nAnother sentence."
 
@@ -17,7 +18,8 @@ def test_add_metadata_one_row_append():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello"}
-    results = [add_metadata(chunk, metadata, append=True) for chunk in chunks]
+    transformer = add_metadata(metadata, append=True)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "This is a test.message: hello\n"
     assert results[1] == "Another sentence.message: hello\n"
 
@@ -26,7 +28,8 @@ def test_add_metadata_multiple_rows():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello", "tag": "first"}
-    results = [add_metadata(chunk, metadata) for chunk in chunks]
+    transformer = add_metadata(metadata)
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message: hello\ntag: first\nThis is a test."
     assert results[1] == "message: hello\ntag: first\nAnother sentence."
 
@@ -35,9 +38,7 @@ def test_add_metadata_custom_delimiters():
     """Test prepending metadata to chunks"""
     chunks = ["This is a test.", "Another sentence."]
     metadata = {"message": "hello", "tag": "first"}
-    results = [
-        add_metadata(chunk, metadata, delimiter="-", line_delimiter="_")
-        for chunk in chunks
-    ]
+    transformer = add_metadata(metadata, delimiter="-", line_delimiter="_")
+    results = [transformer(chunk) for chunk in chunks]
     assert results[0] == "message-hello_tag-first_This is a test."
     assert results[1] == "message-hello_tag-first_Another sentence."