mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Move metadata prepending to a util
This commit is contained in:
parent
c8dbb029f4
commit
247547f5bc
@ -15,5 +15,5 @@ class Chunker(ABC):
|
||||
"""Create a chunker instance."""
|
||||
|
||||
@abstractmethod
|
||||
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
"""Chunk method definition."""
|
||||
|
||||
@ -32,5 +32,5 @@ class ChunkingConfig(BaseModel):
|
||||
)
|
||||
prepend_metadata: bool | None = Field(
|
||||
description="Prepend metadata into each chunk.",
|
||||
default=None,
|
||||
default=False,
|
||||
)
|
||||
|
||||
15
packages/graphrag/graphrag/chunking/prepend_metadata.py
Normal file
15
packages/graphrag/graphrag/chunking/prepend_metadata.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'prepend_metadata' function."""
|
||||
|
||||
|
||||
def prepend_metadata(
    text: str, metadata: dict, delimiter: str = ": ", line_delimiter: str = "\n"
) -> str:
    """Prepend metadata to the given text.

    Renders each key/value pair of *metadata* as its own row
    (``f"{key}{delimiter}{value}"``), joins the rows with *line_delimiter*,
    appends a trailing *line_delimiter*, and prefixes the result to *text*.

    Args:
        text: The text to prefix with the rendered metadata.
        metadata: Key/value pairs to write, one per row.
        delimiter: Separator placed between each key and its value.
        line_delimiter: Separator placed between rows and after the last row.

    Returns:
        The rendered metadata rows followed by the original text; *text*
        unchanged when *metadata* is empty.
    """
    # Guard: an empty dict would otherwise prepend a lone line_delimiter
    # (e.g. "\n" + text). Return the text untouched instead.
    if not metadata:
        return text
    metadata_str = (
        line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
        + line_delimiter
    )
    return metadata_str + text
|
||||
@ -14,16 +14,10 @@ from graphrag.chunking.chunker import Chunker
|
||||
class SentenceChunker(Chunker):
|
||||
"""A chunker that splits text into sentence-based chunks."""
|
||||
|
||||
def __init__(self, prepend_metadata: bool = False, **kwargs: Any) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Create a sentence chunker instance."""
|
||||
self._prepend_metadata = prepend_metadata
|
||||
bootstrap()
|
||||
|
||||
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
|
||||
def chunk(self, text) -> list[str]:
|
||||
"""Chunk the text into sentence-based chunks."""
|
||||
chunks = nltk.sent_tokenize(text)
|
||||
|
||||
if self._prepend_metadata and metadata is not None:
|
||||
metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
|
||||
chunks = [metadata_str + chunk for chunk in chunks]
|
||||
return chunks
|
||||
return nltk.sent_tokenize(text)
|
||||
|
||||
@ -23,18 +23,16 @@ class TokenChunker(Chunker):
|
||||
size: int,
|
||||
overlap: int,
|
||||
tokenizer: Tokenizer,
|
||||
prepend_metadata: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a token chunker instance."""
|
||||
self._size = size
|
||||
self._overlap = overlap
|
||||
self._prepend_metadata = prepend_metadata
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
"""Chunk the text into token-based chunks."""
|
||||
chunks = split_text_on_tokens(
|
||||
return split_text_on_tokens(
|
||||
text,
|
||||
chunk_size=self._size,
|
||||
chunk_overlap=self._overlap,
|
||||
@ -42,12 +40,6 @@ class TokenChunker(Chunker):
|
||||
decode=self._tokenizer.decode,
|
||||
)
|
||||
|
||||
if self._prepend_metadata and metadata is not None:
|
||||
metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
|
||||
chunks = [metadata_str + chunk for chunk in chunks]
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def split_text_on_tokens(
|
||||
text: str,
|
||||
|
||||
@ -13,6 +13,7 @@ from graphrag_common.types.tokenizer import Tokenizer
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.chunking.chunker import Chunker
|
||||
from graphrag.chunking.chunker_factory import create_chunker
|
||||
from graphrag.chunking.prepend_metadata import prepend_metadata as prepend_metadata_fn
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
@ -39,6 +40,7 @@ async def run_workflow(
|
||||
context.callbacks,
|
||||
tokenizer=tokenizer,
|
||||
chunker=chunker,
|
||||
prepend_metadata=config.chunks.prepend_metadata,
|
||||
)
|
||||
|
||||
await write_table_to_storage(output, "text_units", context.output_storage)
|
||||
@ -52,6 +54,7 @@ def create_base_text_units(
|
||||
callbacks: WorkflowCallbacks,
|
||||
tokenizer: Tokenizer,
|
||||
chunker: Chunker,
|
||||
prepend_metadata: bool | None = False,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform base text_units."""
|
||||
documents.sort_values(by=["id"], ascending=[True], inplace=True)
|
||||
@ -63,10 +66,12 @@ def create_base_text_units(
|
||||
logger.info("Starting chunking process for %d documents", total_rows)
|
||||
|
||||
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
|
||||
metadata = row.get("metadata")
|
||||
if (metadata is not None) and isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
row["chunks"] = chunker.chunk(row["text"], metadata=metadata)
|
||||
row["chunks"] = chunker.chunk(row["text"])
|
||||
|
||||
metadata = row.get("metadata", None)
|
||||
if prepend_metadata and metadata is not None:
|
||||
metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
|
||||
row["chunks"] = [prepend_metadata_fn(chunk, metadata) for chunk in row["chunks"]]
|
||||
tick()
|
||||
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
|
||||
return row
|
||||
|
||||
@ -48,18 +48,6 @@ class TestRunSentences:
|
||||
assert chunks[0] == " Sentence with spaces."
|
||||
assert chunks[1] == "Another one!"
|
||||
|
||||
def test_prepend_metadata(self):
|
||||
"""Test prepending metadata to chunks"""
|
||||
input = "This is a test. Another sentence."
|
||||
config = ChunkingConfig(
|
||||
strategy=ChunkStrategyType.Sentence, prepend_metadata=True
|
||||
)
|
||||
chunker = create_chunker(config)
|
||||
chunks = chunker.chunk(input, metadata={"message": "hello"})
|
||||
|
||||
assert chunks[0] == "message: hello.\nThis is a test."
|
||||
assert chunks[1] == "message: hello.\nAnother sentence."
|
||||
|
||||
|
||||
class TestRunTokens:
|
||||
@patch("tiktoken.get_encoding")
|
||||
@ -83,20 +71,6 @@ class TestRunTokens:
|
||||
|
||||
assert len(chunks) > 0
|
||||
|
||||
def test_prepend_metadata(self):
|
||||
"""Test prepending metadata to chunks"""
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
input = "This is a test."
|
||||
config = ChunkingConfig(
|
||||
strategy=ChunkStrategyType.Tokens, size=5, overlap=0, prepend_metadata=True
|
||||
)
|
||||
chunker = create_chunker(config, tokenizer=mocked_tokenizer)
|
||||
chunks = chunker.chunk(input, metadata={"message": "hello"})
|
||||
|
||||
assert chunks[0] == "message: hello.\nThis "
|
||||
assert chunks[1] == "message: hello.\nis a "
|
||||
assert chunks[2] == "message: hello.\ntest."
|
||||
|
||||
|
||||
def test_split_text_str_empty():
|
||||
result = split_text_on_tokens(
|
||||
|
||||
34
tests/unit/chunking/test_prepend_metadata.py
Normal file
34
tests/unit/chunking/test_prepend_metadata.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.chunking.prepend_metadata import prepend_metadata
|
||||
|
||||
|
||||
def test_prepend_metadata_one_row():
    """A single-key metadata dict is prepended to every chunk."""
    meta = {"message": "hello"}
    texts = ["This is a test.", "Another sentence."]
    prefixed = []
    for text in texts:
        prefixed.append(prepend_metadata(text, meta))
    assert prefixed == [
        "message: hello\nThis is a test.",
        "message: hello\nAnother sentence.",
    ]
|
||||
|
||||
|
||||
def test_prepend_metadata_multiple_rows():
    """Each metadata key becomes its own row before the chunk text."""
    meta = {"message": "hello", "tag": "first"}
    expected_prefix = "message: hello\ntag: first\n"
    for text in ("This is a test.", "Another sentence."):
        assert prepend_metadata(text, meta) == expected_prefix + text
|
||||
|
||||
|
||||
def test_prepend_metadata_custom_delimiters():
    """Custom key/value and row delimiters are honored."""
    meta = {"message": "hello", "tag": "first"}
    outputs = [
        prepend_metadata(text, meta, delimiter="-", line_delimiter="_")
        for text in ["This is a test.", "Another sentence."]
    ]
    first, second = outputs
    assert first == "message-hello_tag-first_This is a test."
    assert second == "message-hello_tag-first_Another sentence."
|
||||
Loading…
Reference in New Issue
Block a user