Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-14 00:57:23 +08:00)

commit 780a03827c (parent eb22d7a61b)

    Add prepending tests
@@ -15,5 +15,5 @@ class Chunker(ABC):
         """Create a chunker instance."""
 
     @abstractmethod
-    def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
+    def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
        """Chunk method definition."""
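After this change a chunker receives metadata as an already-parsed dict (or None); JSON decoding moves up to the workflow (see the run_workflow hunk below). A minimal sketch of a conforming subclass; the ParagraphChunker name and its splitting rule are illustrative, not part of this commit:

    from abc import ABC, abstractmethod


    class Chunker(ABC):  # stand-in for the ABC edited above
        @abstractmethod
        def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
            """Chunk method definition."""


    class ParagraphChunker(Chunker):
        """Hypothetical chunker that splits on blank lines."""

        def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
            chunks = [p.strip() for p in text.split("\n\n") if p.strip()]
            if metadata:
                # Same prepend convention the built-in chunkers use below.
                prefix = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
                chunks = [prefix + c for c in chunks]
            return chunks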
@@ -3,7 +3,6 @@
 
 """A module containing 'SentenceChunker' class."""
 
-import json
 from typing import Any
 
 import nltk
@@ -20,19 +19,11 @@ class SentenceChunker(Chunker):
         self._prepend_metadata = prepend_metadata
         bootstrap()
 
-    def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
+    def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
         """Chunk the text into sentence-based chunks."""
         chunks = nltk.sent_tokenize(text)
 
         if self._prepend_metadata and metadata is not None:
-            line_delimiter = ".\n"
-            metadata_str = ""
-            if isinstance(metadata, str):
-                metadata = json.loads(metadata)
-            if isinstance(metadata, dict):
-                metadata_str = (
-                    line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
-                    + line_delimiter
-                )
+            metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
             chunks = [metadata_str + chunk for chunk in chunks]
         return chunks
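The collapsed one-liner is easy to verify in isolation; this standalone snippet (with illustrative metadata values) reproduces the prefix the chunker builds:

    metadata = {"title": "Q3 Report", "source": "web"}
    metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
    # -> "title: Q3 Report.\nsource: web.\n"
    # Every key-value pair ends with ".\n", so the prefix reads as complete
    # sentences when prepended to each chunk.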
@@ -3,7 +3,6 @@
 
 """A module containing 'TokenChunker' class."""
 
-import json
 from collections.abc import Callable
 from typing import Any
 
@@ -23,7 +22,6 @@ class TokenChunker(Chunker):
         self,
         size: int,
         overlap: int,
-        encoding_model: str,
         tokenizer: Tokenizer,
         prepend_metadata: bool = False,
         chunk_size_includes_metadata: bool = False,
@@ -32,25 +30,18 @@ class TokenChunker(Chunker):
         """Create a token chunker instance."""
         self._size = size
         self._overlap = overlap
-        self._encoding_model = encoding_model
         self._prepend_metadata = prepend_metadata
         self._chunk_size_includes_metadata = chunk_size_includes_metadata
         self._tokenizer = tokenizer
 
-    def chunk(self, text: str, metadata: str | dict | None = None) -> list[str]:
+    def chunk(self, text: str, metadata: dict | None = None) -> list[str]:
         """Chunk the text into token-based chunks."""
         line_delimiter = ".\n"
         # we have to create and measure the metadata first to account for the length when chunking
         metadata_str = ""
         metadata_tokens = 0
 
         if self._prepend_metadata and metadata is not None:
-            if isinstance(metadata, str):
-                metadata = json.loads(metadata)
-            if isinstance(metadata, dict):
-                metadata_str = (
-                    line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
-                    + line_delimiter
-                )
+            metadata_str = ".\n".join(f"{k}: {v}" for k, v in metadata.items()) + ".\n"
 
         if self._chunk_size_includes_metadata:
             metadata_tokens = len(self._tokenizer.encode(metadata_str))
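The trailing context shows why the metadata is built and measured up front: when chunk_size_includes_metadata is set, the prefix length must come out of each chunk's token budget. The subtraction itself is outside this hunk, so the helper below is an assumed sketch of that accounting, not the repository's code:

    def encode(s: str) -> list[int]:
        # Char-level encoding, like the tests' MockTokenizer further down.
        return [ord(c) for c in s]


    def effective_chunk_size(size: int, metadata_str: str) -> int:
        # Hypothetical helper: reserve the metadata's tokens out of `size`.
        return size - len(encode(metadata_str))


    effective_chunk_size(200, "message: hello.\n")  # -> 184 tokens left per chunk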
@@ -3,6 +3,7 @@
 
 """A module containing run_workflow method definition."""
 
+import json
 import logging
 from typing import Any, cast
@@ -62,7 +63,10 @@ def create_base_text_units(
     logger.info("Starting chunking process for %d documents", total_rows)
 
     def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
-        row["chunks"] = chunker.chunk(row["text"], row.get("metadata"))
+        metadata = row.get("metadata")
+        if (metadata is not None) and isinstance(metadata, str):
+            metadata = json.loads(metadata)
+        row["chunks"] = chunker.chunk(row["text"], metadata=metadata)
         tick()
         logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
         return row
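With this change JSON decoding happens once at the workflow boundary, and chunkers only ever receive parsed dicts. The same normalization shown standalone, with illustrative row data:

    import json

    import pandas as pd

    row = pd.Series({"text": "Some document.", "metadata": '{"source": "web"}'})
    metadata = row.get("metadata")
    if metadata is not None and isinstance(metadata, str):
        metadata = json.loads(metadata)  # '{"source": "web"}' -> {"source": "web"}
    # chunker.chunk(row["text"], metadata=metadata) now receives a dict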
@@ -11,14 +11,15 @@ from graphrag.chunking.token_chunker import (
     split_text_on_tokens,
 )
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
+from graphrag_common.types.tokenizer import Tokenizer
 
 
-class MockTokenizer:
-    def encode(self, text):
+class MockTokenizer(Tokenizer):
+    def encode(self, text) -> list[int]:
         return [ord(char) for char in text]
 
-    def decode(self, token_ids):
-        return "".join(chr(id) for id in token_ids)
+    def decode(self, tokens) -> str:
+        return "".join(chr(id) for id in tokens)
 
 
 tokenizer = get_tokenizer()
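MockTokenizer maps every character to its Unicode code point, so one token equals one character and the expected chunk boundaries in the token tests can be computed by hand:

    tok = MockTokenizer()
    tok.encode("This ")                  # [84, 104, 105, 115, 32]
    tok.decode([84, 104, 105, 115, 32])  # "This "
    len(tok.encode("This is a test."))   # 15 tokens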
@@ -53,6 +54,18 @@ class TestRunSentences:
         assert chunks[0] == " Sentence with spaces."
         assert chunks[1] == "Another one!"
 
+    def test_prepend_metadata(self):
+        """Test prepending metadata to chunks"""
+        input = "This is a test. Another sentence."
+        config = ChunkingConfig(
+            strategy=ChunkStrategyType.Sentence, prepend_metadata=True
+        )
+        chunker = create_chunker(config)
+        chunks = chunker.chunk(input, metadata={"message": "hello"})
+
+        assert chunks[0] == "message: hello.\nThis is a test."
+        assert chunks[1] == "message: hello.\nAnother sentence."
+
 
 class TestRunTokens:
     @patch("tiktoken.get_encoding")
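The expected values follow from NLTK's sentence splitter (available because the chunker calls bootstrap() on construction): the input splits into two sentences, and each one gets the metadata prefix:

    import nltk

    nltk.sent_tokenize("This is a test. Another sentence.")
    # -> ['This is a test.', 'Another sentence.']
    # Prefixing each with "message: hello.\n" yields exactly the asserted chunks.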
@@ -75,6 +88,20 @@ class TestRunTokens:
 
         assert len(chunks) > 0
 
+    def test_prepend_metadata(self):
+        """Test prepending metadata to chunks"""
+        mocked_tokenizer = MockTokenizer()
+        input = "This is a test."
+        config = ChunkingConfig(
+            strategy=ChunkStrategyType.Tokens, size=5, overlap=0, prepend_metadata=True
+        )
+        chunker = create_chunker(config, tokenizer=mocked_tokenizer)
+        chunks = chunker.chunk(input, metadata={"message": "hello"})
+
+        assert chunks[0] == "message: hello.\nThis "
+        assert chunks[1] == "message: hello.\nis a "
+        assert chunks[2] == "message: hello.\ntest."
+
 
 def test_split_text_str_empty():
     result = split_text_on_tokens(
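These expectations are also hand-checkable: with the char-level MockTokenizer, size=5, and overlap=0, the 15-token input splits into three 5-token windows, and since chunk_size_includes_metadata defaults to False the prefix is not charged against that budget. A simplified sketch of the windowing (the real path goes through split_text_on_tokens):

    text = "This is a test."  # 15 chars == 15 tokens under MockTokenizer
    size, overlap = 5, 0
    [text[i : i + size] for i in range(0, len(text), size - overlap)]
    # -> ['This ', 'is a ', 'test.']; each then gets "message: hello.\n" prepended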
@@ -146,40 +173,3 @@ def test_split_text_on_tokens_no_overlap():
         encode=tok.encode,
     )
     assert result == expected_splits
-
-
-@patch("tiktoken.get_encoding")
-def test_get_encoding_fn_encode(mock_get_encoding):
-    # Create a mock encoding object with encode and decode methods
-    mock_encoding = Mock()
-    mock_encoding.encode = Mock(return_value=[1, 2, 3])
-    mock_encoding.decode = Mock(return_value="decoded text")
-
-    # Configure the mock_get_encoding to return the mock encoding object
-    mock_get_encoding.return_value = mock_encoding
-
-    # Call the function to get encode and decode functions
-    tokenizer = get_tokenizer(encoding_model="mock_encoding")
-
-    # Test the encode function
-    encoded_text = tokenizer.encode("test text")
-    assert encoded_text == [1, 2, 3]
-    mock_encoding.encode.assert_called_once_with("test text")
-
-
-@patch("tiktoken.get_encoding")
-def test_get_encoding_fn_decode(mock_get_encoding):
-    # Create a mock encoding object with encode and decode methods
-    mock_encoding = Mock()
-    mock_encoding.encode = Mock(return_value=[1, 2, 3])
-    mock_encoding.decode = Mock(return_value="decoded text")
-
-    # Configure the mock_get_encoding to return the mock encoding object
-    mock_get_encoding.return_value = mock_encoding
-
-    # Call the function to get encode and decode functions
-    tokenizer = get_tokenizer(encoding_model="mock_encoding")
-
-    decoded_text = tokenizer.decode([1, 2, 3])
-    assert decoded_text == "decoded text"
-    mock_encoding.decode.assert_called_once_with([1, 2, 3])