mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Add ChunkResult model
This commit is contained in:
parent
ee20153d8c
commit
a741bfb8d7
17
packages/graphrag-chunking/graphrag_chunking/chunk_result.py
Normal file
17
packages/graphrag-chunking/graphrag_chunking/chunk_result.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The ChunkResult dataclass."""

from dataclasses import dataclass


@dataclass
class ChunkResult:
    """Result of chunking a document."""

    # The chunk's text content.
    text: str
    # 0-based position of this chunk within the document's chunk sequence.
    index: int
    # 0-based offset of the chunk's first character in the source text.
    start_char: int
    # 0-based offset of the chunk's last character (inclusive) in the source text.
    end_char: int
    # Number of tokens in `text`; None when no encoder was supplied.
    token_count: int | None = None
|
||||
@ -6,6 +6,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from graphrag_chunking.chunk_result import ChunkResult
|
||||
|
||||
|
||||
class Chunker(ABC):
|
||||
"""Abstract base class for document chunkers."""
|
||||
@ -15,5 +17,5 @@ class Chunker(ABC):
|
||||
"""Create a chunker instance."""
|
||||
|
||||
    @abstractmethod
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk method definition.

        Split the given text into chunks, returning one ChunkResult per chunk.
        """
|
||||
|
||||
@ -0,0 +1,30 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'create_chunk_results' function."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag_chunking.chunk_result import ChunkResult
|
||||
|
||||
|
||||
def create_chunk_results(
    chunks: list[str],
    encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
    """Create chunk results from a list of text chunks.

    Character offsets are 0-based and inclusive, and assume the chunks were not
    stripped relative to the source text (i.e. the chunks are contiguous).

    Args:
        chunks: The chunk texts, in source order.
        encode: Optional tokenizer; when provided, each result's token_count is
            set to the length of the encoded chunk text.

    Returns:
        One ChunkResult per input chunk with index/start_char/end_char assigned.
    """
    results: list[ChunkResult] = []
    start_char = 0
    for index, text in enumerate(chunks):
        end_char = start_char + len(text) - 1  # 0-based, inclusive
        # NOTE: use a distinct name for the result rather than rebinding the
        # loop variable (the original shadowed `chunk` with the ChunkResult).
        result = ChunkResult(
            text=text,
            index=index,
            start_char=start_char,
            end_char=end_char,
        )
        if encode:
            result.token_count = len(encode(text))
        results.append(result)
        # Next chunk is assumed to start immediately after this one.
        start_char = end_char + 1
    return results
|
||||
@ -3,21 +3,42 @@
|
||||
|
||||
"""A module containing 'SentenceChunker' class."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import nltk
|
||||
|
||||
from graphrag_chunking.bootstrap_nltk import bootstrap
|
||||
from graphrag_chunking.chunk_result import ChunkResult
|
||||
from graphrag_chunking.chunker import Chunker
|
||||
from graphrag_chunking.create_chunk_results import create_chunk_results
|
||||
|
||||
|
||||
class SentenceChunker(Chunker):
    """A chunker that splits text into sentence-based chunks."""

    def __init__(
        self, encode: Callable[[str], list[int]] | None = None, **kwargs: Any
    ) -> None:
        """Create a sentence chunker instance.

        Args:
            encode: Optional tokenizer used to populate each chunk's token_count.
            **kwargs: Unused by this chunker; accepted for factory compatibility.
        """
        self._encode = encode
        # Make sure the required nltk data is downloaded before tokenizing.
        bootstrap()

    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into sentence-based chunks."""
        sentences = nltk.sent_tokenize(text.strip())
        results = create_chunk_results(sentences, encode=self._encode)
        # nltk sentence tokenizer may trim whitespace, so we need to adjust start/end chars
        for index, result in enumerate(results):
            txt = result.text
            start = result.start_char
            # Find where this sentence actually begins in the original
            # (unstripped) text, searching from the assumed offset.
            actual_start = text.find(txt, start)
            delta = actual_start - start
            if delta > 0:
                result.start_char += delta
                result.end_char += delta
                # bump the next to keep the start check from falling too far behind
                if index < len(results) - 1:
                    results[index + 1].start_char += delta
                    results[index + 1].end_char += delta
        return results
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from graphrag_chunking.chunk_result import ChunkResult
|
||||
from graphrag_chunking.chunker import Chunker
|
||||
from graphrag_chunking.create_chunk_results import create_chunk_results
|
||||
|
||||
|
||||
class TokenChunker(Chunker):
|
||||
@ -26,15 +28,16 @@ class TokenChunker(Chunker):
|
||||
self._encode = encode
|
||||
self._decode = decode
|
||||
|
||||
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into token-based chunks."""
        # Split on token boundaries, then wrap each piece with positional metadata.
        chunks = split_text_on_tokens(
            text,
            chunk_size=self._size,
            chunk_overlap=self._overlap,
            encode=self._encode,
            decode=self._decode,
        )
        # NOTE(review): create_chunk_results assumes contiguous, non-overlapping
        # chunks; with chunk_overlap > 0 the start/end offsets will not line up
        # with the source text — confirm this is intended.
        return create_chunk_results(chunks, encode=self._encode)
|
||||
|
||||
|
||||
def split_text_on_tokens(
|
||||
|
||||
@ -66,7 +66,7 @@ def create_base_text_units(
|
||||
logger.info("Starting chunking process for %d documents", total_rows)
|
||||
|
||||
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
|
||||
row["chunks"] = chunker.chunk(row["text"])
|
||||
row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
|
||||
|
||||
metadata = row.get("metadata", None)
|
||||
if prepend_metadata and metadata is not None:
|
||||
|
||||
@ -28,22 +28,42 @@ class TestRunSentences:
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test basic sentence splitting without metadata"""
|
||||
input = "This is a test. Another sentence."
|
||||
input = "This is a test. Another sentence. And a third one!"
|
||||
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
|
||||
chunks = chunker.chunk(input)
|
||||
|
||||
assert len(chunks) == 3
|
||||
assert chunks[0].text == "This is a test."
|
||||
assert chunks[0].index == 0
|
||||
assert chunks[0].start_char == 0
|
||||
assert chunks[0].end_char == 14
|
||||
|
||||
assert chunks[1].text == "Another sentence."
|
||||
assert chunks[1].index == 1
|
||||
assert chunks[1].start_char == 16
|
||||
assert chunks[1].end_char == 32
|
||||
|
||||
assert chunks[2].text == "And a third one!"
|
||||
assert chunks[2].index == 2
|
||||
assert chunks[2].start_char == 34
|
||||
assert chunks[2].end_char == 49
|
||||
|
||||
def test_mixed_whitespace_handling(self):
|
||||
"""Test input with irregular whitespace"""
|
||||
input = " Sentence with spaces. Another one! "
|
||||
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
|
||||
chunks = chunker.chunk(input)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0] == "This is a test."
|
||||
assert chunks[1] == "Another sentence."
|
||||
assert chunks[0].text == "Sentence with spaces."
|
||||
assert chunks[0].index == 0
|
||||
assert chunks[0].start_char == 3
|
||||
assert chunks[0].end_char == 23
|
||||
|
||||
def test_mixed_whitespace_handling(self):
|
||||
"""Test input with irregular whitespace"""
|
||||
|
||||
input = " Sentence with spaces. Another one! "
|
||||
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
|
||||
chunks = chunker.chunk(input)
|
||||
assert chunks[0] == " Sentence with spaces."
|
||||
assert chunks[1] == "Another one!"
|
||||
assert chunks[1].text == "Another one!"
|
||||
assert chunks[1].index == 1
|
||||
assert chunks[1].start_char == 25
|
||||
assert chunks[1].end_char == 36
|
||||
|
||||
|
||||
class TestRunTokens:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user