Mirror of https://github.com/microsoft/graphrag.git
Synced 2026-01-14 00:57:23 +08:00

Commit 574736c825: add metadata to text
Parent: ad7144eb63
@@ -5,6 +5,7 @@
 import logging
 import re
 from datetime import datetime
 from pathlib import Path
 from typing import Any

@@ -30,7 +31,10 @@ async def load(
     """Load text inputs from a directory."""

     async def load_file(
-        path: str, group: dict | None = None, _encoding: str = "utf-8"
+        path: str,
+        group: dict | None = None,
+        _encoding: str = "utf-8",
+        metadata: list[str] | None = None,
     ) -> dict[str, Any]:
         if group is None:
             group = {}
@@ -38,6 +42,8 @@ async def load(
         new_item = {**group, "text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
+        if metadata and "creation_date" in metadata:
+            new_item["creation_date"] = storage.get_creation_date(path)
         return new_item

     files = list(
@@ -57,7 +63,7 @@ async def load(
     for file, group in files:
         try:
-            files_loaded.append(await load_file(file, group))
+            files_loaded.append(await load_file(file, group, metadata=config.metadata))
         except Exception:  # noqa: BLE001 (catching Exception is fine here)
             log.warning("Warning! Error loading file %s. Skipping...", file)
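To summarize the loader change: config.metadata is an opt-in list of column names, and load_file only attaches creation_date when that column is requested. A minimal sketch of the behavior, with FakeStorage and an inline hash standing in for the real PipelineStorage and gen_sha512_hash:

from hashlib import sha512
from pathlib import Path
from typing import Any


class FakeStorage:
    """Stand-in for PipelineStorage: creation dates keyed by path."""

    def __init__(self, dates: dict[str, str]):
        self._dates = dates

    def get_creation_date(self, key: str) -> str:
        return self._dates.get(key, "")


def load_file(
    storage: FakeStorage,
    path: str,
    text: str,
    group: dict | None = None,
    metadata: list[str] | None = None,
) -> dict[str, Any]:
    """Mirror the loader's new behavior: attach creation_date only on request."""
    group = group or {}
    new_item = {**group, "text": text}
    # Hash over the item's fields, analogous to gen_sha512_hash.
    new_item["id"] = sha512(repr(sorted(new_item.items())).encode()).hexdigest()
    new_item["title"] = str(Path(path).name)
    if metadata and "creation_date" in metadata:
        new_item["creation_date"] = storage.get_creation_date(path)
    return new_item


storage = FakeStorage({"docs/a.txt": "2024-01-01 12:00:00"})
item = load_file(storage, "docs/a.txt", "hello", metadata=["creation_date"])
assert item["creation_date"] == "2024-01-01 12:00:00"
assert "creation_date" not in load_file(storage, "docs/a.txt", "hello")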
@@ -174,14 +174,15 @@ def split_multiple_texts_on_tokens(
         else ""
     )
     metadata_tokens = tokenizer.encode(metadata_str)
     if len(metadata_tokens) >= tokenizer.tokens_per_chunk:
         message = "Metadata tokens exceed the maximum tokens per chunk. Please increase the tokens per chunk."
         raise ValueError(message)

     # Adjust tokenizer to account for metadata tokens
-    adjusted_tokenizer = Tokenizer(
-        chunk_overlap=tokenizer.chunk_overlap,
-        tokens_per_chunk=tokenizer.tokens_per_chunk - len(metadata_tokens),
-        decode=tokenizer.decode,
-        encode=tokenizer.encode,
-    )
+    adjusted_tokenizer = Tokenizer(**{
+        **tokenizer.__dict__,
+        "tokens_per_chunk": tokenizer.tokens_per_chunk - len(metadata_tokens),
+    })

     input_ids = []
     for source_doc_idx, text in enumerate(texts):
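Two things are worth noting in this hunk. First, the guard raises before any chunking if the serialized metadata alone would fill a chunk. Second, the rewritten construction spreads tokenizer.__dict__ rather than naming each field, which keeps the copy correct if Tokenizer ever gains fields; it assumes Tokenizer is a plain dataclass whose fields mirror its constructor arguments. A sketch under that assumption (the metadata serialization shown is assumed, not taken from the commit):

from dataclasses import dataclass
from typing import Callable


@dataclass
class Tokenizer:
    # Field names assumed to mirror graphrag's Tokenizer dataclass.
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def encode(text: str) -> list[int]:
    return list(text.encode())


def decode(ids: list[int]) -> str:
    return bytes(ids).decode()


tokenizer = Tokenizer(chunk_overlap=4, tokens_per_chunk=50, decode=decode, encode=encode)
metadata_tokens = tokenizer.encode("author: Charles.\n")  # 17 byte-tokens

# The **tokenizer.__dict__ spread copies every field, then overrides one key,
# so the metadata budget is carved out of the per-chunk token budget.
adjusted_tokenizer = Tokenizer(**{
    **tokenizer.__dict__,
    "tokens_per_chunk": tokenizer.tokens_per_chunk - len(metadata_tokens),
})
assert adjusted_tokenizer.tokens_per_chunk == 33  # 50 - 17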
@@ -187,6 +187,19 @@ class BlobPipelineStorage(PipelineStorage):
         else:
             return blob_data

+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date of a blob."""
+        try:
+            key = self._keyname(key)
+            container_client = self._blob_service_client.get_container_client(
+                self._container_name
+            )
+            blob_client = container_client.get_blob_client(key)
+            return str(blob_client.download_blob().properties.creation_time)
+        except Exception:
+            log.exception("Error getting key %s", key)
+            return ""
+
     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         """Set a value in the cache."""
         try:
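One observation, not part of the commit: download_blob().properties does carry the blob's creation_time, but download_blob() also begins streaming the blob's content. When only the property is needed, azure-storage-blob's get_blob_properties() returns the same BlobProperties without fetching the body. A hedged sketch (connection string, container, and key are placeholders):

from azure.storage.blob import BlobServiceClient

# Placeholder connection details, for illustration only.
service = BlobServiceClient.from_connection_string("<connection-string>")
blob = service.get_container_client("<container>").get_blob_client("<key>")

# get_blob_properties() returns BlobProperties (creation_time included)
# without streaming the blob body, unlike download_blob().properties.
created = blob.get_blob_properties().creation_time
print(str(created))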
@@ -118,6 +118,10 @@ class FilePipelineStorage(PipelineStorage):
         ) as f:
             await f.write(value)

+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date of a file."""
+        return str(Path(join_path(self._root_dir, key)).stat().st_ctime)
+
     async def has(self, key: str) -> bool:
         """Has method definition."""
         return await exists(join_path(self._root_dir, key))
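A caveat worth flagging on the file implementation: st_ctime is the creation time on Windows, but on POSIX systems it is the inode change time, and macOS exposes true creation time as st_birthtime. It also returns a raw epoch float as a string, where the blob implementation returns a formatted datetime. A small sketch showing both points:

from datetime import datetime, timezone
from pathlib import Path

path = Path("example.txt")
path.touch()

stat = path.stat()
# Creation time on Windows, inode-change time on POSIX;
# macOS exposes real creation time as st_birthtime.
created = getattr(stat, "st_birthtime", stat.st_ctime)
print(str(created))  # raw epoch float, as the diff returns it
print(datetime.fromtimestamp(created, tz=timezone.utc).isoformat())  # readable form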
@@ -80,3 +80,15 @@ class PipelineStorage(metaclass=ABCMeta):
     @abstractmethod
     def keys(self) -> list[str]:
         """List all keys in the storage."""
+
+    @abstractmethod
+    def get_creation_date(self, key: str) -> str:
+        """Get the creation date for the given key.
+
+        Args:
+            - key - The key to get the creation date for.
+
+        Returns
+        -------
+            - output - The creation date for the given key.
+        """
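Because the method is declared @abstractmethod on the base class, every concrete PipelineStorage must now implement it: the file and blob backends above, plus any custom ones. A minimal hypothetical subclass, sketched against just this slice of the interface:

from abc import ABCMeta, abstractmethod
from datetime import datetime, timezone


class PipelineStorage(metaclass=ABCMeta):
    # Only the new method is reproduced here; the real class has many more.
    @abstractmethod
    def get_creation_date(self, key: str) -> str:
        """Get the creation date for the given key."""


class MemoryPipelineStorage(PipelineStorage):
    """Hypothetical in-memory backend recording insertion time per key."""

    def __init__(self) -> None:
        self._created: dict[str, str] = {}

    def put(self, key: str) -> None:
        self._created.setdefault(
            key, datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %z")
        )

    def get_creation_date(self, key: str) -> str:
        return self._created.get(key, "")


storage = MemoryPipelineStorage()
storage.put("doc.txt")
print(storage.get_creation_date("doc.txt"))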
@@ -67,19 +67,15 @@ class TestRunSentences:
 class TestRunTokens:
     @patch("tiktoken.get_encoding")
     def test_basic_functionality(self, mock_get_encoding):
         """Test basic token-based chunking."""
         # Mock tiktoken encoding
         mock_encoder = Mock()
-        mock_encoder.encode.side_effect = lambda x: list(
-            x.encode()
-        )  # Simulate encoding
-        mock_encoder.decode.side_effect = lambda x: bytes(
-            x
-        ).decode()  # Simulate decoding
+        mock_encoder.encode.side_effect = lambda x: list(x.encode())
+        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

         # Input and config
-        input = ["hello world"]
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
         config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
         tick = Mock()

@@ -87,7 +83,7 @@ class TestRunTokens:
         chunks = list(run_tokens(input, config, tick))

         # Verify output
-        assert len(chunks) > 0  # At least one chunk should be produced
+        assert len(chunks) > 0
         assert all(isinstance(chunk, TextChunk) for chunk in chunks)
         tick.assert_called_once_with(1)
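The mock deserves a word: encode maps a string to its UTF-8 byte values and decode inverts that, so each byte behaves like one token id and round-trips are exact for this ASCII input. For example:

def encode(x: str) -> list[int]:
    return list(x.encode())


def decode(ids: list[int]) -> str:
    return bytes(ids).decode()


ids = encode("Marley")
assert ids == [77, 97, 114, 108, 101, 121]  # one "token" per byte
assert decode(ids) == "Marley"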
@@ -99,16 +95,18 @@ class TestRunTokens:
         mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

-        input = ["test"]
-        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
+        config = ChunkingConfig(size=50, overlap=4, encoding_model="fake-encoding")
         tick = Mock()
-        metadata = {"author": "John"}
+        metadata = {"author": "Charles"}

         chunks = list(run_tokens(input, config, tick, metadata))

         # Verify metadata is included in the chunk
         assert len(chunks) > 0
-        assert "author: John" in chunks[0].text_chunk
+        assert "author: Charles" in chunks[0].text_chunk

     @patch("tiktoken.get_encoding")
     def test_custom_delimiter(self, mock_get_encoding):
@@ -118,8 +116,10 @@ class TestRunTokens:
         mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
         mock_get_encoding.return_value = mock_encoder

-        input = ["test"]
-        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
+        config = ChunkingConfig(size=50, overlap=4, encoding_model="fake-encoding")
         tick = Mock()
         metadata = {"key": "value"}
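The config changes in these two tests follow from the new guard. Assuming the metadata is serialized roughly as "author: Charles" (the form the assert checks for), that prefix alone encodes to 15 byte-tokens under the mock, so the old size=5 budget would trip the "Metadata tokens exceed the maximum tokens per chunk" ValueError before any chunking happened; size=50 leaves headroom. The arithmetic:

def encode(x: str) -> list[int]:
    return list(x.encode())


metadata_prefix = "author: Charles"  # assumed serialization, per the assert above
assert len(encode(metadata_prefix)) == 15  # >= the old size of 5, so ValueError
assert len(encode(metadata_prefix)) < 50   # fits the new budget with room to chunk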
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock
 import pandas as pd
 import pytest
 import tiktoken
 from pydantic import ValidationError

 from graphrag.index.operations.chunk_text.typing import TextChunk
 from graphrag.index.text_splitting.text_splitting import (
@@ -261,6 +262,36 @@ def test_split_multiple_texts_on_tokens_metadata_one_column():
     )
     assert split == expected

+
+def test_split_multiple_texts_on_tokens_metadata_large():
+    input_df = pd.DataFrame({
+        "text": ["Receptionist", "Officer", "Captain"],
+        "command": ["Jump", "Walk", "Run"],
+        "metadata": [
+            "Table 1 red with glass",
+            "Office 1 with central table and 16 chairs",
+            "Ship 1 with weird name",
+        ],
+    })
+
+    mocked_tokenizer = MockTokenizer()
+    mock_tick = MagicMock()
+    tokenizer = Tokenizer(
+        chunk_overlap=0,
+        tokens_per_chunk=5,
+        decode=mocked_tokenizer.decode,
+        encode=lambda text: mocked_tokenizer.encode(text),
+    )
+
+    texts = input_df["text"].to_numpy().tolist()
+    metadata = input_df["metadata"].to_numpy().tolist()
+
+    with pytest.raises(
+        ValueError,
+        match="Metadata tokens exceed the maximum tokens per chunk. Please increase the tokens per chunk.",
+    ):
+        split_multiple_texts_on_tokens(
+            [texts[0]], tokenizer, tick=mock_tick, metadata={"metadata": metadata[0]}
+        )
+
+
 def test_split_multiple_texts_on_tokens_metadata_two_columns():
     input_df = pd.DataFrame({
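This new test exercises the guard from the opposite direction: with tokens_per_chunk=5, even the first metadata value cannot fit, so pytest.raises expects the ValueError. Assuming MockTokenizer encodes roughly one token per character and the pair is serialized as "metadata: Table 1 red with glass" (both assumptions, not shown in the diff), the overflow is immediate:

def encode(text: str) -> list[int]:
    # Stand-in for MockTokenizer: one token per character.
    return [ord(c) for c in text]


value = "metadata: Table 1 red with glass"  # approximate serialized form
tokens_per_chunk = 5
assert len(encode(value)) >= tokens_per_chunk  # 32 >= 5, so the guard raises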