add metadata to text

Dayenne Souza 2025-01-30 17:12:01 -03:00
parent ad7144eb63
commit 574736c825
7 changed files with 91 additions and 24 deletions

View File

@@ -5,6 +5,7 @@
 import logging
 import re
+from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -30,7 +31,10 @@ async def load(
     """Load text inputs from a directory."""

     async def load_file(
-        path: str, group: dict | None = None, _encoding: str = "utf-8"
+        path: str,
+        group: dict | None = None,
+        _encoding: str = "utf-8",
+        metadata: list[str] | None = None,
     ) -> dict[str, Any]:
         if group is None:
             group = {}
@@ -38,6 +42,8 @@ async def load(
         new_item = {**group, "text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
+        if metadata and "creation_date" in metadata:
+            new_item["creation_date"] = storage.get_creation_date(path)
         return new_item

     files = list(
@@ -57,7 +63,7 @@ async def load(
     for file, group in files:
         try:
-            files_loaded.append(await load_file(file, group))
+            files_loaded.append(await load_file(file, group, metadata=config.metadata))
         except Exception:  # noqa: BLE001 (catching Exception is fine here)
             log.warning("Warning! Error loading file %s. Skipping...", file)
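Read together, this file's changes thread an opt-in metadata list from the input config down to each loaded document: "creation_date" is only fetched when the config asks for it. A minimal, self-contained sketch of the resulting contract (DummyStorage is hypothetical, standing in for the storage backends extended below):

# Hypothetical stand-in for the loader's storage dependency.
class DummyStorage:
    def get_creation_date(self, key: str) -> str:
        return "2025-01-30 17:12:01"

def build_item(path: str, text: str, storage, metadata: list[str] | None) -> dict:
    # Mirrors load_file(): title is always set, creation_date only on request.
    item = {"text": text, "title": path.rsplit("/", 1)[-1]}
    if metadata and "creation_date" in metadata:
        item["creation_date"] = storage.get_creation_date(path)
    return item

item = build_item("docs/a.txt", "hello", DummyStorage(), metadata=["creation_date"])
assert item["creation_date"] == "2025-01-30 17:12:01"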

View File

@@ -174,14 +174,15 @@ def split_multiple_texts_on_tokens(
         else ""
     )
     metadata_tokens = tokenizer.encode(metadata_str)
+    if len(metadata_tokens) >= tokenizer.tokens_per_chunk:
+        message = "Metadata tokens exceed the maximum tokens per chunk. Please increase the tokens per chunk."
+        raise ValueError(message)

     # Adjust tokenizer to account for metadata tokens
-    adjusted_tokenizer = Tokenizer(
-        chunk_overlap=tokenizer.chunk_overlap,
-        tokens_per_chunk=tokenizer.tokens_per_chunk - len(metadata_tokens),
-        decode=tokenizer.decode,
-        encode=tokenizer.encode,
-    )
+    adjusted_tokenizer = Tokenizer(**{
+        **tokenizer.__dict__,
+        "tokens_per_chunk": tokenizer.tokens_per_chunk - len(metadata_tokens),
+    })

     input_ids = []
     for source_doc_idx, text in enumerate(texts):
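The rewritten construction copies every field of the existing tokenizer and overrides only the chunk budget; the **tokenizer.__dict__ trick assumes Tokenizer is a plain dataclass whose fields map one-to-one onto its constructor arguments. A self-contained sketch of the budget arithmetic, with a toy whitespace tokenizer (Tok) standing in for the real one:

from dataclasses import dataclass
from typing import Callable

@dataclass
class Tok:  # toy stand-in for graphrag's Tokenizer
    chunk_overlap: int
    tokens_per_chunk: int
    encode: Callable[[str], list]

tok = Tok(chunk_overlap=4, tokens_per_chunk=50, encode=str.split)
metadata_tokens = tok.encode("author: Charles")  # 2 toy tokens

if len(metadata_tokens) >= tok.tokens_per_chunk:
    raise ValueError("Metadata tokens exceed the maximum tokens per chunk.")

# Same copy-and-override trick as the diff: all fields kept, budget shrunk.
adjusted = Tok(**{
    **tok.__dict__,
    "tokens_per_chunk": tok.tokens_per_chunk - len(metadata_tokens),
})
assert adjusted.tokens_per_chunk == 48  # every chunk reserves 2 tokens for metadata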

View File

@@ -187,6 +187,19 @@ class BlobPipelineStorage(PipelineStorage):
         else:
             return blob_data

+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date of a blob."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            return str(blob_client.download_blob().properties.creation_time)
+        except Exception:
+            log.exception("Error getting key %s", key)
+            return ""
+
     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         """Set a value in the cache."""
         try:
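One thing worth noting about the blob implementation: download_blob() transfers the blob's content just to read one property. If the deployed azure-storage-blob version exposes BlobClient.get_blob_properties() (a metadata-only call), a lighter drop-in variant of the method above would be:

def get_creation_date(self, key: str) -> str:
    """Get the creation date of a blob without downloading its content."""
    try:
        blob_client = self._blob_service_client.get_container_client(
            self._container_name
        ).get_blob_client(self._keyname(key))
        # get_blob_properties() issues a HEAD-style request; no payload transfer.
        return str(blob_client.get_blob_properties().creation_time)
    except Exception:
        log.exception("Error getting key %s", key)
        return ""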

View File

@@ -118,6 +118,10 @@ class FilePipelineStorage(PipelineStorage):
         ) as f:
             await f.write(value)

+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date of a file."""
+        return str(Path(join_path(self._root_dir, key)).stat().st_ctime)
+
     async def has(self, key: str) -> bool:
        """Has method definition."""
        return await exists(join_path(self._root_dir, key))
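A caveat on the file-backed version: st_ctime is creation time on Windows but inode-change time on most Unix filesystems, and str(st_ctime) yields a raw epoch float (e.g. "1738267921.0") while the blob backend returns a formatted datetime. If consistent, human-readable output matters to callers, a hedged alternative is to render it as an ISO timestamp:

from datetime import datetime, timezone
from pathlib import Path

def creation_date_iso(path: str) -> str:
    """Render st_ctime as an ISO-8601 UTC timestamp (semantics still platform-dependent)."""
    ts = Path(path).stat().st_ctime
    return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()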

View File

@@ -80,3 +80,15 @@ class PipelineStorage(metaclass=ABCMeta):
     @abstractmethod
     def keys(self) -> list[str]:
         """List all keys in the storage."""
+
+    @abstractmethod
+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date for the given key.
+
+        Args:
+            - key - The key to get the creation date for.
+
+        Returns
+        -------
+            - output - The creation date for the given key.
+        """

View File

@@ -67,19 +67,15 @@ class TestRunSentences:

 class TestRunTokens:
     @patch("tiktoken.get_encoding")
     def test_basic_functionality(self, mock_get_encoding):
-        """Test basic token-based chunking."""
-        # Mock tiktoken encoding
         mock_encoder = Mock()
-        mock_encoder.encode.side_effect = lambda x: list(
-            x.encode()
-        )  # Simulate encoding
-        mock_encoder.decode.side_effect = lambda x: bytes(
-            x
-        ).decode()  # Simulate decoding
+        mock_encoder.encode.side_effect = lambda x: list(x.encode())
+        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

         # Input and config
-        input = ["hello world"]
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
         config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
         tick = Mock()
@@ -87,7 +83,7 @@ class TestRunTokens:
         chunks = list(run_tokens(input, config, tick))

         # Verify output
-        assert len(chunks) > 0  # At least one chunk should be produced
+        assert len(chunks) > 0
         assert all(isinstance(chunk, TextChunk) for chunk in chunks)

         tick.assert_called_once_with(1)
@@ -99,16 +95,18 @@ class TestRunTokens:
         mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

-        input = ["test"]
-        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
+        config = ChunkingConfig(size=50, overlap=4, encoding_model="fake-encoding")
         tick = Mock()
-        metadata = {"author": "John"}
+        metadata = {"author": "Charles"}

         chunks = list(run_tokens(input, config, tick, metadata))

         # Verify metadata is included in the chunk
         assert len(chunks) > 0
-        assert "author: John" in chunks[0].text_chunk
+        assert "author: Charles" in chunks[0].text_chunk

     @patch("tiktoken.get_encoding")
     def test_custom_delimiter(self, mock_get_encoding):
@@ -118,8 +116,10 @@ class TestRunTokens:
         mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

-        input = ["test"]
-        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
+        config = ChunkingConfig(size=50, overlap=4, encoding_model="fake-encoding")
         tick = Mock()
         metadata = {"key": "value"}
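The reworked tests swap the one-word inputs for a passage long enough to actually be chunked, and they pin down the user-visible contract: metadata is rendered as "key: value" text and prepended to each chunk, which is why a plain substring assertion on chunks[0] suffices. A toy restatement of that contract (the exact separator between the metadata header and the chunk body is assumed here; the tests only assert the substring):

metadata = {"author": "Charles"}
header = ".\n".join(f"{k}: {v}" for k, v in metadata.items())  # separator assumed
chunk = header + ".\n" + "Marley was dead: to begin with."
assert "author: Charles" in chunk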

View File

@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
 import pandas as pd
 import pytest
 import tiktoken
+from pydantic import ValidationError

 from graphrag.index.operations.chunk_text.typing import TextChunk
 from graphrag.index.text_splitting.text_splitting import (
@@ -261,6 +262,36 @@ def test_split_multiple_texts_on_tokens_metadata_one_column():
     )

     assert split == expected


+def test_split_multiple_texts_on_tokens_metadata_large():
+    input_df = pd.DataFrame({
+        "text": ["Receptionist", "Officer", "Captain"],
+        "command": ["Jump", "Walk", "Run"],
+        "metadata": [
+            "Table 1 red with glass",
+            "Office 1 with central table and 16 chairs",
+            "Ship 1 with weird name",
+        ],
+    })
+    mocked_tokenizer = MockTokenizer()
+    mock_tick = MagicMock()
+    tokenizer = Tokenizer(
+        chunk_overlap=0,
+        tokens_per_chunk=5,
+        decode=mocked_tokenizer.decode,
+        encode=lambda text: mocked_tokenizer.encode(text),
+    )
+
+    texts = input_df["text"].to_numpy().tolist()
+    metadata = input_df["metadata"].to_numpy().tolist()
+
+    with pytest.raises(
+        ValueError,
+        match="Metadata tokens exceed the maximum tokens per chunk. Please increase the tokens per chunk.",
+    ):
+        split_multiple_texts_on_tokens(
+            [texts[0]], tokenizer, tick=mock_tick, metadata={"metadata": metadata[0]}
+        )
+
+
 def test_split_multiple_texts_on_tokens_metadata_two_columns():
     input_df = pd.DataFrame({