mirror of https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00

Merge branch 'v3/main' into input-factory

commit 8e3c7170f7
@@ -100,12 +100,11 @@ These settings configure how we parse documents into text chunks. This is necess

#### Fields

- `strategy` **str**[tokens|sentences] - How to chunk the text.
- `size` **int** - The max chunk size in tokens.
- `overlap` **int** - The chunk overlap in tokens.
- `encoding_model` **str** - The text encoding model to use for splitting on token boundaries.
- `prepend_metadata` **bool** - Determines if metadata values should be added at the beginning of each chunk. Default=`False`.
- `chunk_size_includes_metadata` **bool** - Specifies whether the chunk size calculation should include metadata tokens. Default=`False`.

## Outputs and Storage
@@ -82,10 +82,9 @@ As described above, when documents are imported you can specify a list of `metad

### Chunking Config

Next, the `chunks` block needs to instruct the chunker how to handle this metadata when creating text units. By default, it is ignored. We have two settings to include it:
Next, the `chunks` block needs to instruct the chunker how to handle this metadata when creating text units. By default, it is ignored. We have the following setting to include it:

- `prepend_metadata`: This instructs the importer to copy the contents of the `metadata` column for each row into the start of every single text chunk. This metadata is copied as key: value pairs on new lines.
- `chunk_size_includes_metadata`: This tells the chunker how to compute the chunk size when metadata is included. By default, we create the text units using your specified `chunk_size` *and then* prepend the metadata. This means that the final text unit lengths may be longer than your configured `chunk_size`, and will vary with the length of each document's metadata. When this setting is `True`, we measure the metadata length first and fill each chunk with raw text from the remaining token budget, so that the resulting text units always comply with your configured `chunk_size`.
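For example, with `size: 100` and `chunk_size_includes_metadata: true`, a metadata header of, say, 12 tokens leaves 88 tokens of raw text per chunk; metadata that meets or exceeds the chunk size is an error. A minimal sketch of that budget arithmetic, using a tiktoken encoder purely for illustration:

```python
import tiktoken

encode = tiktoken.get_encoding("o200k_base").encode

chunk_size = 100
metadata_str = "headline: US to lift most federal COVID-19 vaccine mandates\n"

# Tokens consumed by the prepended metadata header
metadata_tokens = len(encode(metadata_str))
if metadata_tokens >= chunk_size:
    raise ValueError("Metadata tokens exceed the maximum tokens per chunk.")

# Budget left for raw document text in every chunk
text_budget = chunk_size - metadata_tokens
print(text_budget)
```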

### Examples

@@ -124,7 +123,6 @@ chunks:
  size: 100
  overlap: 0
  prepend_metadata: true
  chunk_size_includes_metadata: false
```

Documents DataFrame

@@ -162,54 +160,6 @@ US to lift most federal COVID-19 vaccine mandates,WASHINGTON (AP) The Biden admi

NY lawmakers begin debating budget 1 month after due date,ALBANY, N.Y. (AP) New York lawmakers began voting Monday on a $229 billion state budget due a month ago that would raise the minimum wage, crack down on illicit pot shops and ban gas stoves and furnaces in new buildings. Negotiations among Gov. Kathy Hochul and her fellow Democrats in control of the Legislature dragged on past the April 1 budget deadline, largely because of disagreements over changes to the bail law and other policy proposals included in the spending plan. Floor debates on some budget bills began Monday. State Senate Majority Leader Andrea Stewart-Cousins said she expected voting to be wrapped up Tuesday for a budget she said contains "significant wins" for New Yorkers. "I would have liked to have done this sooner. I think we would all agree to that," Cousins told reporters before voting began. "This has been a very policy-laden budget and a lot of the policies had to parsed through." Hochul was able to push through a change to the bail law that will eliminate the standard that requires judges to prescribe the "least restrictive" means to ensure defendants return to court. Hochul said judges needed the extra discretion. Some liberal lawmakers argued that it would undercut the sweeping bail reforms approved in 2019 and result in more people with low incomes and people of color in pretrial detention. Here are some other policy provisions that will be included in the budget, according to state officials. The minimum wage would be raised to $17 in New York City and some of its suburbs and $16 in the rest of the state by 2026. That's up from $15 in the city and $14.20 upstate.

--

settings.yaml

```yaml
input:
  file_type: csv
  title_column: headline
  text_column: article
  metadata: [headline]

chunks:
  size: 50
  overlap: 5
  prepend_metadata: true
  chunk_size_includes_metadata: true
```

Documents DataFrame

| id | title | text | creation_date | metadata |
| --------------------- | --------------------------------------------------------- | ------------------------ | ----------------------------- | --------------------------------------------------------------------------- |
| (generated from text) | US to lift most federal COVID-19 vaccine mandates | (article column content) | (create date of articles.csv) | { "headline": "US to lift most federal COVID-19 vaccine mandates" } |
| (generated from text) | NY lawmakers begin debating budget 1 month after due date | (article column content) | (create date of articles.csv) | { "headline": "NY lawmakers begin debating budget 1 month after due date" } |

Raw Text Chunks

| content | length |
| ------- | ------: |
| title: US to lift most federal COVID-19 vaccine mandates<br>WASHINGTON (AP) The Biden administration will end most of the last remaining federal COVID-19 vaccine requirements next week when the national public health emergency for the coronavirus ends, the White House said Monday. Vaccine requirements for federal workers and federal contractors, | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>federal workers and federal contractors as well as foreign air travelers to the U.S., will end May 11. The government is also beginning the process of lifting shot requirements for Head Start educators, healthcare workers, and noncitizens at U.S. land borders. | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>noncitizens at U.S. land borders. The requirements are among the last vestiges of some of the more coercive measures taken by the federal government to promote vaccination as the deadly virus raged, and their end marks the latest display of how | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>the latest display of how President Joe Biden's administration is moving to treat COVID-19 as a routine, endemic illness. "While I believe that these vaccine mandates had a tremendous beneficial impact, we are now at a point where we think that | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>point where we think that it makes a lot of sense to pull these requirements down," White House COVID-19 coordinator Dr. Ashish Jha told The Associated Press on Monday. | 38 |
| title: NY lawmakers begin debating budget 1 month after due date<br>ALBANY, N.Y. (AP) New York lawmakers began voting Monday on a $229 billion state budget due a month ago that would raise the minimum wage, crack down on illicit pot shops and ban gas stoves and furnaces in new | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>stoves and furnaces in new buildings. Negotiations among Gov. Kathy Hochul and her fellow Democrats in control of the Legislature dragged on past the April 1 budget deadline, largely because of disagreements over changes to the bail law and | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>to the bail law and other policy proposals included in the spending plan. Floor debates on some budget bills began Monday. State Senate Majority Leader Andrea Stewart-Cousins said she expected voting to be wrapped up Tuesday for a budget | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>up Tuesday for a budget she said contains "significant wins" for New Yorkers. "I would have liked to have done this sooner. I think we would all agree to that," Cousins told reporters before voting began. "This has been | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>voting began. "This has been a very policy-laden budget and a lot of the policies had to parsed through." Hochul was able to push through a change to the bail law that will eliminate the standard that requires judges | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>the standard that requires judges to prescribe the "least restrictive" means to ensure defendants return to court. Hochul said judges needed the extra discretion. Some liberal lawmakers argued that it would undercut the sweeping bail reforms approved in 2019 | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>bail reforms approved in 2019 and result in more people with low incomes and people of color in pretrial detention. Here are some other policy provisions that will be included in the budget, according to state officials. The minimum | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>to state officials. The minimum wage would be raised to $17 in New York City and some of its suburbs and $16 in the rest of the state by 2026. That's up from $15 | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>2026. That's up from $15 in the city and $14.20 upstate. | 22 |

In this example we can see that the two input documents were parsed into fourteen output text chunks. The title (headline) of each document is prepended and included in the computed chunk size, so each chunk matches the configured chunk size (except the last one for each document). We've also configured some overlap in these text chunks, so the last five tokens of each chunk are repeated at the start of the next. Why would you use overlap in your text chunks? Consider that when you are splitting documents based on tokens, it is highly likely that sentences or even related concepts will be split into separate chunks. Each text chunk is processed separately by the language model, so this may result in incomplete "ideas" at the boundaries of the chunk. Overlap ensures that these split concepts are fully contained in at least one of the chunks.
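The window arithmetic behind this is simple: with token chunking, each new chunk begins `size - overlap` tokens after the previous one. A minimal sketch of the boundary calculation (the 233-token document length is an invented figure for illustration):

```python
size, overlap = 50, 5
doc_tokens = 233  # invented token count for one document

# Each window starts `size - overlap` tokens after the previous one.
# (The real splitter stops as soon as a window reaches the end of the document.)
stride = size - overlap
windows = [(start, min(start + size, doc_tokens)) for start in range(0, doc_tokens, stride)]
print(windows[:3])  # [(0, 50), (45, 95), (90, 140)]
```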
#### JSON files

This final example uses a JSON file for each of the same two articles. In this example we'll set the object fields to read, but we will not add metadata to the text chunks.

32 packages/graphrag-chunking/README.md Normal file

@@ -0,0 +1,32 @@
# GraphRAG Chunking

This package contains a collection of text chunkers, a core config model, and a factory for acquiring instances.

## Examples

Basic sentence chunking with nltk

```python
from graphrag_chunking.sentence_chunker import SentenceChunker

chunker = SentenceChunker()
chunks = chunker.chunk("This is a test. Another sentence.")
print([chunk.text for chunk in chunks])  # ["This is a test.", "Another sentence."]
```

Token chunking

```python
import tiktoken

from graphrag_chunking.token_chunker import TokenChunker

tokenizer = tiktoken.get_encoding("o200k_base")
chunker = TokenChunker(size=3, overlap=0, encode=tokenizer.encode, decode=tokenizer.decode)
chunks = chunker.chunk("This is a random test fragment of some text")
print([chunk.text for chunk in chunks])  # ["This is a", " random test fragment", " of some text"]
```

Using the factory via helper util

```python
import tiktoken

from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig

tokenizer = tiktoken.get_encoding("o200k_base")
config = ChunkingConfig(
    type="tokens",
    size=3,
    overlap=0
)
chunker = create_chunker(config, tokenizer.encode, tokenizer.decode)
...
```
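The factory can also host custom implementations. A minimal sketch of registering one, assuming the `register_chunker` helper and the `Chunker`/`ChunkResult` types introduced elsewhere in this change (`MyChunker` and the `"paragraphs"` id are hypothetical):

```python
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunker_factory import create_chunker, register_chunker
from graphrag_chunking.chunking_config import ChunkingConfig


class MyChunker(Chunker):
    """Hypothetical chunker that splits on blank lines."""

    def __init__(self, **kwargs: Any) -> None:
        """Create the chunker; unused config fields arrive here as kwargs."""

    def chunk(self, text: str) -> list[ChunkResult]:
        """Split on paragraph breaks."""
        parts = text.split("\n\n")
        return [
            ChunkResult(text=p, index=i, start_char=0, end_char=len(p) - 1)
            for i, p in enumerate(parts)
        ]


register_chunker("paragraphs", MyChunker)
chunker = create_chunker(ChunkingConfig(type="paragraphs"))
print(chunker.chunk("First paragraph.\n\nSecond paragraph.")[1].text)
```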
@@ -1,2 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""System-level chunking package."""
19 packages/graphrag-chunking/graphrag_chunking/add_metadata.py Normal file

@@ -0,0 +1,19 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the 'add_metadata' function."""


def add_metadata(
    text: str,
    metadata: dict,
    delimiter: str = ": ",
    line_delimiter: str = "\n",
    append: bool = False,
) -> str:
    """Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
    metadata_str = (
        line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
        + line_delimiter
    )
    return text + metadata_str if append else metadata_str + text
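As a quick illustration of the output format this utility produces (the values are invented):

```python
from graphrag_chunking.add_metadata import add_metadata

print(add_metadata("Body text.", {"headline": "Budget passes", "year": 2023}))
# headline: Budget passes
# year: 2023
# Body text.
```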
17 packages/graphrag-chunking/graphrag_chunking/chunk_result.py Normal file

@@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The ChunkResult dataclass."""

from dataclasses import dataclass


@dataclass
class ChunkResult:
    """Result of chunking a document."""

    text: str
    index: int
    start_char: int
    end_char: int
    token_count: int | None = None
@@ -0,0 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Chunk strategy type enumeration."""

from enum import StrEnum


class ChunkerType(StrEnum):
    """ChunkerType class definition."""

    Tokens = "tokens"
    Sentence = "sentence"
21 packages/graphrag-chunking/graphrag_chunking/chunker.py Normal file

@@ -0,0 +1,21 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the 'Chunker' class."""

from abc import ABC, abstractmethod
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult


class Chunker(ABC):
    """Abstract base class for document chunkers."""

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        """Create a chunker instance."""

    @abstractmethod
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk method definition."""
@@ -0,0 +1,77 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'ChunkerFactory', 'register_chunker', and 'create_chunker'."""

from collections.abc import Callable

from graphrag_common.factory.factory import Factory, ServiceScope

from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunking_config import ChunkingConfig


class ChunkerFactory(Factory[Chunker]):
    """Factory for creating Chunker instances."""


chunker_factory = ChunkerFactory()


def register_chunker(
    chunker_type: str,
    chunker_initializer: Callable[..., Chunker],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom chunker implementation.

    Args
    ----
    - chunker_type: str
        The chunker id to register.
    - chunker_initializer: Callable[..., Chunker]
        The chunker initializer to register.
    """
    chunker_factory.register(chunker_type, chunker_initializer, scope)


def create_chunker(
    config: ChunkingConfig,
    encode: Callable[[str], list[int]] | None = None,
    decode: Callable[[list[int]], str] | None = None,
) -> Chunker:
    """Create a chunker implementation based on the given configuration.

    Args
    ----
    - config: ChunkingConfig
        The chunker configuration to use.

    Returns
    -------
    Chunker
        The created chunker implementation.
    """
    config_model = config.model_dump()
    if encode is not None:
        config_model["encode"] = encode
    if decode is not None:
        config_model["decode"] = decode
    chunker_strategy = config.type

    if chunker_strategy not in chunker_factory:
        match chunker_strategy:
            case ChunkerType.Tokens:
                from graphrag_chunking.token_chunker import TokenChunker

                register_chunker(ChunkerType.Tokens, TokenChunker)
            case ChunkerType.Sentence:
                from graphrag_chunking.sentence_chunker import SentenceChunker

                register_chunker(ChunkerType.Sentence, SentenceChunker)
            case _:
                msg = f"ChunkingConfig.type '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
                raise ValueError(msg)

    return chunker_factory.create(chunker_strategy, init_args=config_model)
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from pydantic import BaseModel, ConfigDict, Field

from graphrag_chunking.chunk_strategy_type import ChunkerType


class ChunkingConfig(BaseModel):
    """Configuration section for chunking."""

    model_config = ConfigDict(extra="allow")
    """Allow extra fields to support custom chunker implementations."""

    type: str = Field(
        description="The chunking type to use.",
        default=ChunkerType.Tokens,
    )
    encoding_model: str | None = Field(
        description="The encoding model to use.",
        default=None,
    )
    size: int = Field(
        description="The chunk size to use.",
        default=1200,
    )
    overlap: int = Field(
        description="The chunk overlap to use.",
        default=100,
    )
    prepend_metadata: bool = Field(
        description="Prepend metadata into each chunk.",
        default=False,
    )
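Note the `extra="allow"` setting above: unrecognized config fields survive `model_dump()`, which is how `create_chunker` forwards them into a chunker's initializer as keyword arguments. A small sketch (`min_sentence_len` and the `"paragraphs"` type are hypothetical):

```python
from graphrag_chunking.chunking_config import ChunkingConfig

config = ChunkingConfig(type="paragraphs", min_sentence_len=10)  # extra field is allowed
print(config.model_dump()["min_sentence_len"])  # 10; this reaches the chunker's **kwargs
```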
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'create_chunk_results' function."""

from collections.abc import Callable

from graphrag_chunking.chunk_result import ChunkResult


def create_chunk_results(
    chunks: list[str],
    encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks were not stripped relative to the source text."""
    results = []
    start_char = 0
    for index, chunk in enumerate(chunks):
        end_char = start_char + len(chunk) - 1  # 0-based indices
        chunk = ChunkResult(
            text=chunk,
            index=index,
            start_char=start_char,
            end_char=end_char,
        )
        if encode:
            chunk.token_count = len(encode(chunk.text))
        results.append(chunk)
        start_char = end_char + 1
    return results
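For instance, the computed character offsets line up end to end (a sketch; with no `encode` supplied, `token_count` stays `None`):

```python
from graphrag_chunking.create_chunk_results import create_chunk_results

results = create_chunk_results(["abc", "defg"])
print([(r.index, r.start_char, r.end_char) for r in results])
# [(0, 0, 2), (1, 3, 6)]
```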
@@ -0,0 +1,44 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'SentenceChunker' class."""

from collections.abc import Callable
from typing import Any

import nltk

from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results


class SentenceChunker(Chunker):
    """A chunker that splits text into sentence-based chunks."""

    def __init__(
        self, encode: Callable[[str], list[int]] | None = None, **kwargs: Any
    ) -> None:
        """Create a sentence chunker instance."""
        self._encode = encode
        bootstrap()

    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into sentence-based chunks."""
        sentences = nltk.sent_tokenize(text.strip())
        results = create_chunk_results(sentences, encode=self._encode)
        # nltk sentence tokenizer may trim whitespace, so we need to adjust start/end chars
        for index, result in enumerate(results):
            txt = result.text
            start = result.start_char
            actual_start = text.find(txt, start)
            delta = actual_start - start
            if delta > 0:
                result.start_char += delta
                result.end_char += delta
                # shift the following result too, so the next find() starts from the right place
                if index < len(results) - 1:
                    results[index + 1].start_char += delta
                    results[index + 1].end_char += delta
        return results
@@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'TokenChunker' class."""

from collections.abc import Callable
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results


class TokenChunker(Chunker):
    """A chunker that splits text into token-based chunks."""

    def __init__(
        self,
        size: int,
        overlap: int,
        encode: Callable[[str], list[int]],
        decode: Callable[[list[int]], str],
        **kwargs: Any,
    ) -> None:
        """Create a token chunker instance."""
        self._size = size
        self._overlap = overlap
        self._encode = encode
        self._decode = decode

    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into token-based chunks."""
        chunks = split_text_on_tokens(
            text,
            chunk_size=self._size,
            chunk_overlap=self._overlap,
            encode=self._encode,
            decode=self._decode,
        )
        return create_chunk_results(chunks, encode=self._encode)


def split_text_on_tokens(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    encode: Callable[[str], list[int]],
    decode: Callable[[list[int]], str],
) -> list[str]:
    """Split a single text and return chunks using the tokenizer."""
    result = []
    input_tokens = encode(text)

    start_idx = 0
    cur_idx = min(start_idx + chunk_size, len(input_tokens))
    chunk_tokens = input_tokens[start_idx:cur_idx]

    while start_idx < len(input_tokens):
        chunk_text = decode(list(chunk_tokens))
        result.append(chunk_text)  # Append chunked text as string
        if cur_idx == len(input_tokens):
            break
        start_idx += chunk_size - chunk_overlap
        cur_idx = min(start_idx + chunk_size, len(input_tokens))
        chunk_tokens = input_tokens[start_idx:cur_idx]

    return result
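The sliding window is easiest to see with a toy character-level codec in place of a real tokenizer (a sketch for illustration only):

```python
from graphrag_chunking.token_chunker import split_text_on_tokens

chunks = split_text_on_tokens(
    "abcdefgh",
    chunk_size=4,
    chunk_overlap=1,
    encode=lambda s: [ord(c) for c in s],        # one "token" per character
    decode=lambda ids: "".join(map(chr, ids)),
)
print(chunks)  # ['abcd', 'defg', 'gh']
```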
43 packages/graphrag-chunking/pyproject.toml Normal file

@@ -0,0 +1,43 @@
[project]
name = "graphrag-chunking"
version = "2.7.0"
description = "Chunking utilities for GraphRAG"
authors = [
    {name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"},
    {name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"},
    {name = "Chris Trevino", email = "chtrevin@microsoft.com"},
    {name = "David Tittsworth", email = "datittsw@microsoft.com"},
    {name = "Dayenne de Souza", email = "ddesouza@microsoft.com"},
    {name = "Derek Worthen", email = "deworthe@microsoft.com"},
    {name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"},
    {name = "Ha Trinh", email = "trinhha@microsoft.com"},
    {name = "Jonathan Larson", email = "jolarso@microsoft.com"},
    {name = "Josh Bradley", email = "joshbradley@microsoft.com"},
    {name = "Kate Lytvynets", email = "kalytv@microsoft.com"},
    {name = "Kenny Zhang", email = "zhangken@microsoft.com"},
    {name = "Mónica Carvajal"},
    {name = "Nathan Evans", email = "naevans@microsoft.com"},
    {name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"},
    {name = "Sarah Smith", email = "smithsarah@microsoft.com"},
]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.11,<3.14"
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]
dependencies = [
    "graphrag-common==2.7.0",
    "pydantic~=2.10",
]

[project.urls]
Source = "https://github.com/microsoft/graphrag"

[build-system]
requires = ["hatchling>=1.27.0,<2.0.0"]
build-backend = "hatchling.build"
@@ -12,13 +12,10 @@ Backwards compatibility is not guaranteed at this time.
"""

import logging
from typing import Annotated

import annotated_types
from pydantic import PositiveInt, validate_call

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.standard_logging import init_loggers
@@ -55,10 +52,6 @@ logger = logging.getLogger(__name__)
@validate_call(config={"arbitrary_types_allowed": True})
async def generate_indexing_prompts(
    config: GraphRagConfig,
    chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
    overlap: Annotated[
        int, annotated_types.Gt(-1)
    ] = graphrag_config_defaults.chunks.overlap,
    limit: PositiveInt = 15,
    selection_method: DocSelectionType = DocSelectionType.RANDOM,
    domain: str | None = None,
@@ -100,8 +93,6 @@ async def generate_indexing_prompts(
        limit=limit,
        select_method=selection_method,
        logger=logger,
        chunk_size=chunk_size,
        overlap=overlap,
        n_subset_max=n_subset_max,
        k=k,
    )
@@ -306,14 +306,14 @@ def _prompt_tune_cli(
        help="The minimum number of examples to generate/include in the entity extraction prompt.",
    ),
    chunk_size: int = typer.Option(
        graphrag_config_defaults.chunks.size,
        graphrag_config_defaults.chunking.size,
        "--chunk-size",
        help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
        help="The size of each example text chunk. Overrides chunking.size in the configuration file.",
    ),
    overlap: int = typer.Option(
        graphrag_config_defaults.chunks.overlap,
        graphrag_config_defaults.chunking.overlap,
        "--overlap",
        help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
        help="The overlap size for chunking documents. Overrides chunking.overlap in the configuration file.",
    ),
    language: str | None = typer.Option(
        None,
@@ -61,11 +61,11 @@ async def prompt_tune(
    )

    # override chunking config in the configuration
    if chunk_size != graph_config.chunks.size:
        graph_config.chunks.size = chunk_size
    if chunk_size != graph_config.chunking.size:
        graph_config.chunking.size = chunk_size

    if overlap != graph_config.chunks.overlap:
        graph_config.chunks.overlap = overlap
    if overlap != graph_config.chunking.overlap:
        graph_config.chunking.overlap = overlap

    # configure the root logger with the specified log level
    from graphrag.logger.standard_logging import init_loggers
@@ -81,8 +81,6 @@ async def prompt_tune(

    prompts = await api.generate_indexing_prompts(
        config=graph_config,
        chunk_size=chunk_size,
        overlap=overlap,
        limit=limit,
        selection_method=selection_method,
        domain=domain,
@@ -8,13 +8,13 @@ from pathlib import Path
from typing import ClassVar

from graphrag_cache import CacheType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_storage import StorageType

from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
    AsyncType,
    AuthType,
    ChunkStrategyType,
    ModelType,
    NounPhraseExtractorType,
    ReportingType,
@@ -57,15 +57,14 @@ class BasicSearchDefaults:


@dataclass
class ChunksDefaults:
    """Default values for chunks."""
class ChunkingDefaults:
    """Default values for chunking."""

    type: str = ChunkerType.Tokens
    size: int = 1200
    overlap: int = 100
    strategy: ClassVar[ChunkStrategyType] = ChunkStrategyType.tokens
    encoding_model: str = ENCODING_MODEL
    prepend_metadata: bool = False
    chunk_size_includes_metadata: bool = False


@dataclass
@@ -127,7 +126,6 @@ class EmbedTextDefaults:
    batch_size: int = 16
    batch_max_tokens: int = 8191
    names: list[str] = field(default_factory=lambda: default_embeddings)
    strategy: None = None


@dataclass
@@ -140,7 +138,6 @@ class ExtractClaimsDefaults:
        "Any claims or facts that could be relevant to information discovery."
    )
    max_gleanings: int = 1
    strategy: None = None
    model_id: str = DEFAULT_CHAT_MODEL_ID
    model_instance_name: str = "extract_claims"

@@ -154,7 +151,6 @@ class ExtractGraphDefaults:
        default_factory=lambda: ["organization", "person", "geo", "event"]
    )
    max_gleanings: int = 1
    strategy: None = None
    model_id: str = DEFAULT_CHAT_MODEL_ID
    model_instance_name: str = "extract_graph"
@@ -362,7 +358,6 @@ class SummarizeDescriptionsDefaults:
    prompt: None = None
    max_length: int = 500
    max_input_tokens: int = 4_000
    strategy: None = None
    model_id: str = DEFAULT_CHAT_MODEL_ID
    model_instance_name: str = "summarize_descriptions"

@@ -403,7 +398,7 @@ class GraphRagConfigDefaults:
    cache: CacheDefaults = field(default_factory=CacheDefaults)
    input: InputDefaults = field(default_factory=InputDefaults)
    embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
    chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
    chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
    snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
    extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
    extract_graph_nlp: ExtractGraphNLPDefaults = field(
@@ -80,17 +80,6 @@ class AsyncType(str, Enum):
    Threaded = "threaded"


class ChunkStrategyType(str, Enum):
    """ChunkStrategy class definition."""

    tokens = "tokens"
    sentence = "sentence"

    def __repr__(self):
        """Get a string representation."""
        return f'"{self.value}"'


class SearchMethod(Enum):
    """The type of search to run."""
@@ -54,9 +54,11 @@ input:
  base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
  file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]

chunks:
  size: {graphrag_config_defaults.chunks.size}
  overlap: {graphrag_config_defaults.chunks.overlap}
chunking:
  type: {graphrag_config_defaults.chunking.type}
  size: {graphrag_config_defaults.chunking.size}
  overlap: {graphrag_config_defaults.chunking.overlap}
  encoding_model: {graphrag_config_defaults.chunking.encoding_model}

### Output/storage settings ###
## If blob storage is specified in the following four sections,
@@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

from pydantic import BaseModel, Field

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import ChunkStrategyType


class ChunkingConfig(BaseModel):
    """Configuration section for chunking."""

    size: int = Field(
        description="The chunk size to use.",
        default=graphrag_config_defaults.chunks.size,
    )
    overlap: int = Field(
        description="The chunk overlap to use.",
        default=graphrag_config_defaults.chunks.overlap,
    )
    strategy: ChunkStrategyType = Field(
        description="The chunking strategy to use.",
        default=graphrag_config_defaults.chunks.strategy,
    )
    encoding_model: str = Field(
        description="The encoding model to use.",
        default=graphrag_config_defaults.chunks.encoding_model,
    )
    prepend_metadata: bool = Field(
        description="Prepend metadata into each chunk.",
        default=graphrag_config_defaults.chunks.prepend_metadata,
    )
    chunk_size_includes_metadata: bool = Field(
        description="Count metadata in max tokens.",
        default=graphrag_config_defaults.chunks.chunk_size_includes_metadata,
    )
@@ -8,6 +8,7 @@ from pathlib import Path

from devtools import pformat
from graphrag_cache import CacheConfig
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_storage import StorageConfig, StorageType
from pydantic import BaseModel, Field, model_validator

@@ -15,7 +16,6 @@ import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@@ -117,9 +117,15 @@ class GraphRagConfig(BaseModel):
            Path(self.input.storage.base_dir).resolve()
        )

    chunks: ChunkingConfig = Field(
    chunking: ChunkingConfig = Field(
        description="The chunking configuration to use.",
        default=ChunkingConfig(),
        default=ChunkingConfig(
            type=graphrag_config_defaults.chunking.type,
            size=graphrag_config_defaults.chunking.size,
            overlap=graphrag_config_defaults.chunking.overlap,
            encoding_model=graphrag_config_defaults.chunking.encoding_model,
            prepend_metadata=graphrag_config_defaults.chunking.prepend_metadata,
        ),
    )
    """The chunking configuration to use."""
@@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Indexing Engine text chunk package root."""
@@ -1,140 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing chunk_text method definitions."""

from typing import Any, cast

import pandas as pd

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.index.operations.chunk_text.typing import (
    ChunkInput,
    ChunkStrategy,
)
from graphrag.logger.progress import ProgressTicker, progress_ticker


def chunk_text(
    input: pd.DataFrame,
    column: str,
    size: int,
    overlap: int,
    encoding_model: str,
    strategy: ChunkStrategyType,
    callbacks: WorkflowCallbacks,
) -> pd.Series:
    """
    Chunk a piece of text into smaller pieces.

    ## Usage
    ```yaml
    args:
        column: <column name> # The name of the column containing the text to chunk, this can either be a column with text, or a column with a list[tuple[doc_id, str]]
        strategy: <strategy config> # The strategy to use to chunk the text, see below for more details
    ```

    ## Strategies
    The text chunk verb uses a strategy to chunk the text. The strategy is an object which defines the strategy to use. The following strategies are available:

    ### tokens
    This strategy uses the [tokens] library to chunk a piece of text. The strategy config is as follows:

    ```yaml
    strategy: tokens
    size: 1200 # Optional, The chunk size to use, default: 1200
    overlap: 100 # Optional, The chunk overlap to use, default: 100
    ```

    ### sentence
    This strategy uses the nltk library to chunk a piece of text into sentences. The strategy config is as follows:

    ```yaml
    strategy: sentence
    ```
    """
    strategy_exec = _load_strategy(strategy)

    num_total = _get_num_total(input, column)
    tick = progress_ticker(callbacks.progress, num_total)

    # collapse the config back to a single object to support "polymorphic" function call
    config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model)

    return cast(
        "pd.Series",
        input.apply(
            cast(
                "Any",
                lambda x: _run_strategy(
                    strategy_exec,
                    x[column],
                    config,
                    tick,
                ),
            ),
            axis=1,
        ),
    )


def _run_strategy(
    strategy_exec: ChunkStrategy,
    input: ChunkInput,
    config: ChunkingConfig,
    tick: ProgressTicker,
) -> list[str | tuple[list[str] | None, str, int]]:
    """Run strategy method definition."""
    if isinstance(input, str):
        return [item.text_chunk for item in strategy_exec([input], config, tick)]

    # We can work with both just a list of text content
    # or a list of tuples of (document_id, text content)
    texts = [item if isinstance(item, str) else item[1] for item in input]

    strategy_results = strategy_exec(texts, config, tick)

    results = []
    for strategy_result in strategy_results:
        doc_indices = strategy_result.source_doc_indices
        if isinstance(input[doc_indices[0]], str):
            results.append(strategy_result.text_chunk)
        else:
            doc_ids = [input[doc_idx][0] for doc_idx in doc_indices]
            results.append((
                doc_ids,
                strategy_result.text_chunk,
                strategy_result.n_tokens,
            ))
    return results


def _load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
    """Load strategy method definition."""
    match strategy:
        case ChunkStrategyType.tokens:
            from graphrag.index.operations.chunk_text.strategies import run_tokens

            return run_tokens
        case ChunkStrategyType.sentence:
            # NLTK
            from graphrag.index.operations.chunk_text.bootstrap import bootstrap
            from graphrag.index.operations.chunk_text.strategies import run_sentences

            bootstrap()
            return run_sentences
        case _:
            msg = f"Unknown strategy: {strategy}"
            raise ValueError(msg)


def _get_num_total(output: pd.DataFrame, column: str) -> int:
    num_total = 0
    for row in output[column]:
        if isinstance(row, str):
            num_total += 1
        else:
            num_total += len(row)
    return num_total
@@ -1,52 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing run_tokens and run_sentences methods."""

from collections.abc import Iterable

import nltk

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import (
    TokenChunkerOptions,
    split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer


def run_tokens(
    input: list[str],
    config: ChunkingConfig,
    tick: ProgressTicker,
) -> Iterable[TextChunk]:
    """Chunks text into chunks based on encoding tokens."""
    tokens_per_chunk = config.size
    chunk_overlap = config.overlap
    tokenizer = get_tokenizer(encoding_model=config.encoding_model)
    return split_multiple_texts_on_tokens(
        input,
        TokenChunkerOptions(
            chunk_overlap=chunk_overlap,
            tokens_per_chunk=tokens_per_chunk,
            encode=tokenizer.encode,
            decode=tokenizer.decode,
        ),
        tick,
    )


def run_sentences(
    input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
    """Chunks text into multiple parts by sentence."""
    for doc_idx, text in enumerate(input):
        sentences = nltk.sent_tokenize(text)
        for sentence in sentences:
            yield TextChunk(
                text_chunk=sentence,
                source_doc_indices=[doc_idx],
            )
        tick(1)
@@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'TextChunk' model."""

from collections.abc import Callable, Iterable
from dataclasses import dataclass

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.logger.progress import ProgressTicker


@dataclass
class TextChunk:
    """Text chunk class definition."""

    text_chunk: str
    source_doc_indices: list[int]
    n_tokens: int | None = None


ChunkInput = str | list[str] | list[tuple[str, str]]
"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""

ChunkStrategy = Callable[
    [list[str], ChunkingConfig, ProgressTicker], Iterable[TextChunk]
]
@@ -8,9 +8,9 @@ import logging
from dataclasses import dataclass

import numpy as np
from graphrag_chunking.token_chunker import split_text_on_tokens

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
from graphrag.index.utils.is_null import is_null
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
@@ -39,17 +39,15 @@ async def run_embed_text(
    if is_null(input):
        return TextEmbeddingResult(embeddings=None)

    splitter = _get_splitter(tokenizer, batch_max_tokens)

    semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads)

    # Break up the input texts. The sizes here indicate how many snippets are in each input text
    texts, input_sizes = _prepare_embed_texts(input, splitter)
    texts, input_sizes = _prepare_embed_texts(input, tokenizer, batch_max_tokens)
    text_batches = _create_text_batches(
        texts,
        tokenizer,
        batch_size,
        batch_max_tokens,
        splitter,
    )
    logger.info(
        "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
@@ -72,13 +70,6 @@ async def run_embed_text(
    return TextEmbeddingResult(embeddings=embeddings)


def _get_splitter(tokenizer: Tokenizer, batch_max_tokens: int) -> TokenTextSplitter:
    return TokenTextSplitter(
        tokenizer=tokenizer,
        chunk_size=batch_max_tokens,
    )


async def _execute(
    model: EmbeddingModel,
    chunks: list[list[str]],
@@ -100,9 +91,9 @@ async def _execute(

def _create_text_batches(
    texts: list[str],
    tokenizer: Tokenizer,
    max_batch_size: int,
    max_batch_tokens: int,
    splitter: TokenTextSplitter,
) -> list[list[str]]:
    """Create batches of texts to embed."""
    # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@@ -112,7 +103,7 @@ def _create_text_batches(
    current_batch_tokens = 0

    for text in texts:
        token_count = splitter.num_tokens(text)
        token_count = tokenizer.num_tokens(text)
        if (
            len(current_batch) >= max_batch_size
            or current_batch_tokens + token_count > max_batch_tokens
@@ -131,18 +122,23 @@ def _create_text_batches(


def _prepare_embed_texts(
    input: list[str], splitter: TokenTextSplitter
    input: list[str],
    tokenizer: Tokenizer,
    batch_max_tokens: int = 8191,
    chunk_overlap: int = 100,
) -> tuple[list[str], list[int]]:
    sizes: list[int] = []
    snippets: list[str] = []

    for text in input:
        # Split the input text and filter out any empty content
        split_texts = splitter.split_text(text)
        if split_texts is None:
            continue
        split_texts = split_text_on_tokens(
            text,
            chunk_size=batch_max_tokens,
            chunk_overlap=chunk_overlap,
            encode=tokenizer.encode,
            decode=tokenizer.decode,
        )
        split_texts = [text for text in split_texts if len(text) > 0]

        sizes.append(len(split_texts))
        snippets.extend(split_texts)
@@ -1,15 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Token limit method definition."""

from graphrag.index.text_splitting.text_splitting import TokenTextSplitter


def check_token_limit(text, max_token):
    """Check token limit."""
    text_splitter = TokenTextSplitter(chunk_size=max_token, chunk_overlap=0)
    docs = text_splitter.split_text(text)
    if len(docs) > 1:
        return 0
    return 1
@@ -1,18 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing 'TokenChunkerOptions', 'TextSplitter', 'NoopTextSplitter', 'TokenTextSplitter', 'split_single_text_on_tokens', and 'split_multiple_texts_on_tokens'."""
"""A module containing 'TokenTextSplitter' class and 'split_single_text_on_tokens' function."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, cast
from abc import ABC
from collections.abc import Callable
from typing import cast

import pandas as pd

from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@@ -24,22 +21,8 @@ LengthFn = Callable[[str], int]
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TokenChunkerOptions:
    """TokenChunkerOptions data class."""

    chunk_overlap: int
    """Overlap in tokens between chunks"""
    tokens_per_chunk: int
    """Maximum number of tokens per chunk"""
    decode: DecodeFn
    """Function to decode a list of token ids to a string"""
    encode: EncodeFn
    """Function to encode a string to a list of token ids"""


class TextSplitter(ABC):
    """Text splitter class definition."""
class TokenTextSplitter(ABC):
    """Token text splitter class definition."""

    _chunk_size: int
    _chunk_overlap: int
@@ -58,6 +41,7 @@ class TextSplitter(ABC):
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
        tokenizer: Tokenizer | None = None,
    ):
        """Init method definition."""
        self._chunk_size = chunk_size
@@ -66,30 +50,6 @@ class TextSplitter(ABC):
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str | list[str]) -> Iterable[str]:
        """Split text method definition."""


class NoopTextSplitter(TextSplitter):
    """Noop text splitter class definition."""

    def split_text(self, text: str | list[str]) -> Iterable[str]:
        """Split text method definition."""
        return [text] if isinstance(text, str) else text


class TokenTextSplitter(TextSplitter):
    """Token text splitter class definition."""

    def __init__(
        self,
        tokenizer: Tokenizer | None = None,
        **kwargs: Any,
    ):
        """Init method definition."""
        super().__init__(**kwargs)
        self._tokenizer = tokenizer or get_tokenizer()

    def num_tokens(self, text: str) -> int:
@@ -106,68 +66,37 @@ class TokenTextSplitter(TextSplitter):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)

        token_chunker_options = TokenChunkerOptions(
        return split_single_text_on_tokens(
            text,
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=self._tokenizer.encode,
        )

        return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)


def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
def split_single_text_on_tokens(
    text: str,
    tokens_per_chunk: int,
    chunk_overlap: int,
    encode: EncodeFn,
    decode: DecodeFn,
) -> list[str]:
    """Split a single text and return chunks using the tokenizer."""
    result = []
    input_ids = tokenizer.encode(text)
    input_ids = encode(text)

    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]

    while start_idx < len(input_ids):
        chunk_text = tokenizer.decode(list(chunk_ids))
        chunk_text = decode(list(chunk_ids))
        result.append(chunk_text)  # Append chunked text as string
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]

    return result


# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
    texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
) -> list[TextChunk]:
    """Split multiple texts and return chunks with metadata using the tokenizer."""
    result = []
    mapped_ids = []

    for source_doc_idx, text in enumerate(texts):
        encoded = tokenizer.encode(text)
        if tick:
            tick(1)  # Track progress if tick callback is provided
        mapped_ids.append((source_doc_idx, encoded))

    input_ids = [
        (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
    ]

    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]

    while start_idx < len(input_ids):
        chunk_text = tokenizer.decode([id for _, id in chunk_ids])
        doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
        result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        start_idx += tokens_per_chunk - chunk_overlap
        cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]

    return result
@@ -8,15 +8,18 @@ import logging
from typing import Any, cast

import pandas as pd
from graphrag_chunking.add_metadata import add_metadata
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunker_factory import create_chunker

from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage

logger = logging.getLogger(__name__)
@@ -30,17 +33,14 @@ async def run_workflow(
    logger.info("Workflow started: create_base_text_units")
    documents = await load_table_from_storage("documents", context.output_storage)

    chunks = config.chunks

    tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
    chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
    output = create_base_text_units(
        documents,
        context.callbacks,
        chunks.size,
        chunks.overlap,
        chunks.encoding_model,
        strategy=chunks.strategy,
        prepend_metadata=chunks.prepend_metadata,
        chunk_size_includes_metadata=chunks.chunk_size_includes_metadata,
        tokenizer=tokenizer,
        chunker=chunker,
        prepend_metadata=config.chunking.prepend_metadata,
    )

    await write_table_to_storage(output, "text_units", context.output_storage)
@@ -52,70 +52,32 @@ async def run_workflow(
def create_base_text_units(
    documents: pd.DataFrame,
    callbacks: WorkflowCallbacks,
-   size: int,
-   overlap: int,
-   encoding_model: str,
-   strategy: ChunkStrategyType,
-   prepend_metadata: bool,
-   chunk_size_includes_metadata: bool,
+   tokenizer: Tokenizer,
+   chunker: Chunker,
+   prepend_metadata: bool | None = False,
) -> pd.DataFrame:
    """All the steps to transform base text_units."""
    documents.sort_values(by=["id"], ascending=[True], inplace=True)

-   tokenizer = get_tokenizer(encoding_model=encoding_model)
-
-   def chunker(row: pd.Series) -> Any:
-       line_delimiter = ".\n"
-       metadata_str = ""
-       metadata_tokens = 0
-
-       if prepend_metadata and "metadata" in row:
-           metadata = row["metadata"]
-           if isinstance(metadata, str):
-               metadata = json.loads(metadata)
-           if isinstance(metadata, dict):
-               metadata_str = (
-                   line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
-                   + line_delimiter
-               )
-
-       if chunk_size_includes_metadata:
-           metadata_tokens = len(tokenizer.encode(metadata_str))
-           if metadata_tokens >= size:
-               message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
-               raise ValueError(message)
-
-       chunked = chunk_text(
-           pd.DataFrame([row]).reset_index(drop=True),
-           column="text",
-           size=size - metadata_tokens,
-           overlap=overlap,
-           encoding_model=encoding_model,
-           strategy=strategy,
-           callbacks=callbacks,
-       )[0]
-
-       if prepend_metadata:
-           for index, chunk in enumerate(chunked):
-               if isinstance(chunk, str):
-                   chunked[index] = metadata_str + chunk
-               else:
-                   chunked[index] = (
-                       (chunk[0], metadata_str + chunk[1], chunk[2]) if chunk else None
-                   )
-
-       row["chunks"] = chunked
-       return row
    total_rows = len(documents)
    tick = progress_ticker(callbacks.progress, total_rows)

    # Track progress of row-wise apply operation
    total_rows = len(documents)
    logger.info("Starting chunking process for %d documents", total_rows)

    def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
        """Add logging to chunker execution."""
-       result = chunker(row)
+       row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
+
+       metadata = row.get("metadata", None)
+       if prepend_metadata and metadata is not None:
+           metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
+           row["chunks"] = [
+               add_metadata(chunk, metadata, line_delimiter=".\n")
+               for chunk in row["chunks"]
+           ]
        tick()
        logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
-       return result
+       return row

    text_units = documents.apply(
        lambda row: chunker_with_logging(row, row.name), axis=1
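Taken together, the two hunks above replace the raw size/overlap/strategy arguments with a tokenizer plus a factory-built chunker. A hedged usage sketch of the new surface; the module path for create_base_text_units is an assumption, while everything else mirrors the hunks and the new tests:

import pandas as pd

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.index.workflows.create_base_text_units import create_base_text_units  # path assumed
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig

# Build the tokenizer and chunker the way run_workflow now does.
chunking_config = ChunkingConfig(size=100, overlap=10)
tokenizer = get_tokenizer(encoding_model=chunking_config.encoding_model)
chunker = create_chunker(chunking_config, tokenizer.encode, tokenizer.decode)

docs = pd.DataFrame({"id": ["doc-1"], "text": ["Some text to chunk into text units."]})
text_units = create_base_text_units(
    docs,
    NoopWorkflowCallbacks(),
    tokenizer=tokenizer,
    chunker=chunker,
    prepend_metadata=False,
)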
@@ -9,6 +9,7 @@ from typing import Any
import numpy as np
import pandas as pd
from graphrag_cache.noop_cache import NoopCache
+from graphrag_chunking.chunker_factory import create_chunker
from graphrag_storage import create_storage

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@@ -46,8 +47,6 @@ async def load_docs_in_chunks(
    select_method: DocSelectionType,
    limit: int,
    logger: logging.Logger,
-   chunk_size: int,
-   overlap: int,
    n_subset_max: int = N_SUBSET_MAX,
    k: int = K,
) -> list[str]:
@@ -63,19 +62,15 @@ async def load_docs_in_chunks(
        cache=NoopCache(),
    )
    tokenizer = get_tokenizer(embeddings_llm_settings)
+   chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
    input_storage = create_storage(config.input.storage)
    input_reader = create_input_reader(config.input, input_storage)
    dataset = await input_reader.read_files()
-   chunk_config = config.chunks
    chunks_df = create_base_text_units(
        documents=pd.DataFrame(dataset),
        callbacks=NoopWorkflowCallbacks(),
-       size=chunk_size,
-       overlap=overlap,
-       encoding_model=chunk_config.encoding_model,
-       strategy=chunk_config.strategy,
-       prepend_metadata=chunk_config.prepend_metadata,
-       chunk_size_includes_metadata=chunk_config.chunk_size_includes_metadata,
+       tokenizer=tokenizer,
+       chunker=chunker,
    )

    # Depending on the select method, build the dataset
@@ -12,7 +12,7 @@ from graphrag.tokenizer.tokenizer import Tokenizer

def get_tokenizer(
    model_config: LanguageModelConfig | None = None,
-   encoding_model: str = ENCODING_MODEL,
+   encoding_model: str | None = None,
) -> Tokenizer:
    """
    Get the tokenizer for the given model configuration or fallback to a tiktoken based tokenizer.
@@ -38,4 +38,6 @@ def get_tokenizer(

        return LitellmTokenizer(model_name=model_config.model)

+   if encoding_model is None:
+       encoding_model = ENCODING_MODEL
    return TiktokenTokenizer(encoding_name=encoding_model)
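With the signature change above, callers can omit encoding_model entirely and let the function resolve the default internally. Both call shapes used by the new tests look like this:

from graphrag.tokenizer.get_tokenizer import get_tokenizer

# No model config and no encoding name: the tiktoken fallback
# resolves ENCODING_MODEL inside the function (per the hunk above).
default_tokenizer = get_tokenizer()

# Explicit encoding name, as test_split_text_on_tokens_one_overlap uses.
o200k = get_tokenizer(encoding_model="o200k_base")
assert o200k.decode(o200k.encode("hello world")) == "hello world"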
@@ -53,6 +53,7 @@ package = false
members = ["packages/*"]

[tool.uv.sources]
+graphrag-chunking = { workspace = true }
graphrag-common = { workspace = true }
graphrag-storage = { workspace = true }
graphrag-cache = { workspace = true }
@@ -70,6 +71,7 @@ _semversioner_release = "semversioner release"
_semversioner_changelog = "semversioner changelog > CHANGELOG.md"
# Add more update toml tasks as packages are added
_semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
+_semversioner_update_graphrag_chunking_toml_version = "update-toml update --file packages/graphrag-chunking/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
@@ -107,6 +109,7 @@ sequence = [
# Add more update toml tasks as packages are added
'_semversioner_update_graphrag_toml_version',
'_semversioner_update_graphrag_common_toml_version',
+'_semversioner_update_graphrag_chunking_toml_version',
'_semversioner_update_graphrag_storage_toml_version',
'_semversioner_update_graphrag_cache_toml_version',
'_semversioner_update_workspace_dependency_versions',
tests/unit/chunking/test_chunker.py (new file)
@@ -0,0 +1,186 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from unittest.mock import Mock, patch

from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_chunking.token_chunker import (
    split_text_on_tokens,
)


class MockTokenizer(Tokenizer):
    def encode(self, text) -> list[int]:
        return [ord(char) for char in text]

    def decode(self, tokens) -> str:
        return "".join(chr(id) for id in tokens)


class TestRunSentences:
    def setup_method(self, method):
        bootstrap()

    def test_basic_functionality(self):
        """Test basic sentence splitting without metadata"""
        input = "This is a test. Another sentence. And a third one!"
        chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
        chunks = chunker.chunk(input)

        assert len(chunks) == 3
        assert chunks[0].text == "This is a test."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 0
        assert chunks[0].end_char == 14

        assert chunks[1].text == "Another sentence."
        assert chunks[1].index == 1
        assert chunks[1].start_char == 16
        assert chunks[1].end_char == 32

        assert chunks[2].text == "And a third one!"
        assert chunks[2].index == 2
        assert chunks[2].start_char == 34
        assert chunks[2].end_char == 49

    def test_mixed_whitespace_handling(self):
        """Test input with irregular whitespace"""
        input = "   Sentence with spaces. Another one!   "
        chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
        chunks = chunker.chunk(input)

        assert len(chunks) == 2
        assert chunks[0].text == "Sentence with spaces."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 3
        assert chunks[0].end_char == 23

        assert chunks[1].text == "Another one!"
        assert chunks[1].index == 1
        assert chunks[1].start_char == 25
        assert chunks[1].end_char == 36


class TestRunTokens:
    @patch("tiktoken.get_encoding")
    def test_basic_functionality(self, mock_get_encoding):
        mock_encoder = Mock()
        mock_encoder.encode.side_effect = lambda x: list(x.encode())
        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
        mock_get_encoding.return_value = mock_encoder

        input = "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."

        config = ChunkingConfig(
            size=5,
            overlap=1,
            encoding_model="fake-encoding",
            type=ChunkerType.Tokens,
        )

        chunker = create_chunker(config, mock_encoder.encode, mock_encoder.decode)
        chunks = chunker.chunk(input)

        assert len(chunks) > 0


def test_split_text_str_empty():
    tokenizer = get_tokenizer()
    result = split_text_on_tokens(
        "",
        chunk_size=5,
        chunk_overlap=2,
        encode=tokenizer.encode,
        decode=tokenizer.decode,
    )
    assert result == []


def test_split_text_on_tokens():
    text = "This is a test text, meaning to be taken seriously by this test only."
    mocked_tokenizer = MockTokenizer()
    expected_splits = [
        "This is a ",
        "is a test ",
        "test text,",
        "text, mean",
        " meaning t",
        "ing to be ",
        "o be taken",
        "taken seri",  # cspell:disable-line
        " seriously",
        "ously by t",  # cspell:disable-line
        " by this t",
        "his test o",
        "est only.",
    ]

    result = split_text_on_tokens(
        text=text,
        chunk_overlap=5,
        chunk_size=10,
        decode=mocked_tokenizer.decode,
        encode=lambda text: mocked_tokenizer.encode(text),
    )
    assert result == expected_splits


def test_split_text_on_tokens_one_overlap():
    text = "This is a test text, meaning to be taken seriously by this test only."
    tokenizer = get_tokenizer(encoding_model="o200k_base")

    expected_splits = [
        "This is",
        " is a",
        " a test",
        " test text",
        " text,",
        ", meaning",
        " meaning to",
        " to be",
        " be taken",
        " taken seriously",
        " seriously by",
        " by this",
        " this test",
        " test only",
        " only.",
    ]

    result = split_text_on_tokens(
        text=text,
        chunk_size=2,
        chunk_overlap=1,
        decode=tokenizer.decode,
        encode=tokenizer.encode,
    )
    assert result == expected_splits


def test_split_text_on_tokens_no_overlap():
    text = "This is a test text, meaning to be taken seriously by this test only."
    tokenizer = get_tokenizer(encoding_model="o200k_base")

    expected_splits = [
        "This is a",
        " test text,",
        " meaning to be",
        " taken seriously by",
        " this test only",
        ".",
    ]

    result = split_text_on_tokens(
        text=text,
        chunk_size=3,
        chunk_overlap=0,
        decode=tokenizer.decode,
        encode=tokenizer.encode,
    )

    assert result == expected_splits
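The sentence-chunker assertions above imply that each returned chunk carries its text plus positional bookkeeping. A minimal record type consistent with those asserts; the field names are read off the tests, and the real class in graphrag_chunking may differ:

from dataclasses import dataclass


@dataclass
class ChunkSketch:
    """Shape implied by the asserts: text plus index and character span."""

    text: str
    index: int  # ordinal position of the chunk within the document
    start_char: int  # offset of the first character in the source text
    end_char: int  # offset of the last character (inclusive, per the asserts)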
tests/unit/chunking/test_prepend_metadata.py (new file)
@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag_chunking.add_metadata import add_metadata


def test_add_metadata_one_row():
    """Test prepending metadata to chunks"""
    chunks = ["This is a test.", "Another sentence."]
    metadata = {"message": "hello"}
    results = [add_metadata(chunk, metadata) for chunk in chunks]
    assert results[0] == "message: hello\nThis is a test."
    assert results[1] == "message: hello\nAnother sentence."


def test_add_metadata_one_row_append():
    """Test appending metadata to chunks"""
    chunks = ["This is a test.", "Another sentence."]
    metadata = {"message": "hello"}
    results = [add_metadata(chunk, metadata, append=True) for chunk in chunks]
    assert results[0] == "This is a test.message: hello\n"
    assert results[1] == "Another sentence.message: hello\n"


def test_add_metadata_multiple_rows():
    """Test prepending metadata to chunks"""
    chunks = ["This is a test.", "Another sentence."]
    metadata = {"message": "hello", "tag": "first"}
    results = [add_metadata(chunk, metadata) for chunk in chunks]
    assert results[0] == "message: hello\ntag: first\nThis is a test."
    assert results[1] == "message: hello\ntag: first\nAnother sentence."


def test_add_metadata_custom_delimiters():
    """Test prepending metadata to chunks with custom delimiters"""
    chunks = ["This is a test.", "Another sentence."]
    metadata = {"message": "hello", "tag": "first"}
    results = [
        add_metadata(chunk, metadata, delimiter="-", line_delimiter="_")
        for chunk in chunks
    ]
    assert results[0] == "message-hello_tag-first_This is a test."
    assert results[1] == "message-hello_tag-first_Another sentence."
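These expectations pin down add_metadata fairly tightly: render key{delimiter}value{line_delimiter} for each pair, then prepend (or, with append=True, append) the result to the chunk. A sketch that satisfies all four tests; the default delimiter values are inferred from the assertions, not copied from the package:

def add_metadata_sketch(
    chunk: str,
    metadata: dict,
    delimiter: str = ": ",
    line_delimiter: str = "\n",
    append: bool = False,
) -> str:
    """Attach rendered metadata key/value pairs to a text chunk (sketch)."""
    rendered = "".join(
        f"{key}{delimiter}{value}{line_delimiter}" for key, value in metadata.items()
    )
    return chunk + rendered if append else rendered + chunk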
@@ -5,7 +5,6 @@ from dataclasses import asdict

import graphrag.config.defaults as defs
from graphrag.config.models.basic_search_config import BasicSearchConfig
-from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@@ -29,6 +28,7 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.index.input.input_config import InputConfig
from graphrag_cache import CacheConfig
+from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_storage import StorageConfig
from pydantic import BaseModel
@@ -164,10 +164,9 @@ def assert_text_embedding_configs(
def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
    assert actual.size == expected.size
    assert actual.overlap == expected.overlap
-   assert actual.strategy == expected.strategy
+   assert actual.type == expected.type
    assert actual.encoding_model == expected.encoding_model
    assert actual.prepend_metadata == expected.prepend_metadata
-   assert actual.chunk_size_includes_metadata == expected.chunk_size_includes_metadata


def assert_snapshots_configs(
@@ -345,7 +344,7 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
    assert_cache_configs(actual.cache, expected.cache)
    assert_input_configs(actual.input, expected.input)
    assert_text_embedding_configs(actual.embed_text, expected.embed_text)
-   assert_chunking_configs(actual.chunks, expected.chunks)
+   assert_chunking_configs(actual.chunking, expected.chunking)
    assert_snapshots_configs(actual.snapshots, expected.snapshots)
    assert_extract_graph_configs(actual.extract_graph, expected.extract_graph)
    assert_extract_graph_nlp_configs(
@@ -1,180 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License


from unittest import mock
from unittest.mock import ANY, Mock

import pandas as pd
import pytest
from graphrag.config.enums import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import (
    _get_num_total,
    _load_strategy,
    _run_strategy,
    chunk_text,
)
from graphrag.index.operations.chunk_text.typing import (
    TextChunk,
)


def test_get_num_total_default():
    output = pd.DataFrame({"column": ["a", "b", "c"]})

    total = _get_num_total(output, "column")
    assert total == 3


def test_get_num_total_array():
    output = pd.DataFrame({"column": [["a", "b", "c"], ["x", "y"]]})

    total = _get_num_total(output, "column")
    assert total == 5


def test_load_strategy_tokens():
    strategy_type = ChunkStrategyType.tokens

    strategy_loaded = _load_strategy(strategy_type)

    assert strategy_loaded.__name__ == "run_tokens"


def test_load_strategy_sentence():
    strategy_type = ChunkStrategyType.sentence

    strategy_loaded = _load_strategy(strategy_type)

    assert strategy_loaded.__name__ == "run_sentences"


def test_load_strategy_none():
    strategy_type = ChunkStrategyType

    with pytest.raises(
        ValueError, match="Unknown strategy: <enum 'ChunkStrategyType'>"
    ):
        _load_strategy(strategy_type)  # type: ignore


def test_run_strategy_str():
    input = "text test for run strategy"
    config = Mock()
    tick = Mock()
    strategy_mocked = Mock()

    strategy_mocked.return_value = [
        TextChunk(
            text_chunk="text test for run strategy",
            source_doc_indices=[0],
        )
    ]

    runned = _run_strategy(strategy_mocked, input, config, tick)
    assert runned == ["text test for run strategy"]


def test_run_strategy_arr_str():
    input = ["text test for run strategy", "use for strategy"]
    config = Mock()
    tick = Mock()
    strategy_mocked = Mock()

    strategy_mocked.return_value = [
        TextChunk(
            text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
        ),
        TextChunk(text_chunk="use for strategy", source_doc_indices=[1], n_tokens=3),
    ]

    expected = [
        "text test for run strategy",
        "use for strategy",
    ]

    runned = _run_strategy(strategy_mocked, input, config, tick)
    assert runned == expected


def test_run_strategy_arr_tuple():
    input = [("text test for run strategy", "3"), ("use for strategy", "5")]
    config = Mock()
    tick = Mock()
    strategy_mocked = Mock()

    strategy_mocked.return_value = [
        TextChunk(
            text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
        ),
        TextChunk(text_chunk="use for strategy", source_doc_indices=[1], n_tokens=3),
    ]

    expected = [
        (
            ["text test for run strategy"],
            "text test for run strategy",
            5,
        ),
        (
            ["use for strategy"],
            "use for strategy",
            3,
        ),
    ]

    runned = _run_strategy(strategy_mocked, input, config, tick)
    assert runned == expected


def test_run_strategy_arr_tuple_same_doc():
    input = [("text test for run strategy", "3"), ("use for strategy", "5")]
    config = Mock()
    tick = Mock()
    strategy_mocked = Mock()

    strategy_mocked.return_value = [
        TextChunk(
            text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
        ),
        TextChunk(text_chunk="use for strategy", source_doc_indices=[0], n_tokens=3),
    ]

    expected = [
        (
            ["text test for run strategy"],
            "text test for run strategy",
            5,
        ),
        (
            ["text test for run strategy"],
            "use for strategy",
            3,
        ),
    ]

    runned = _run_strategy(strategy_mocked, input, config, tick)
    assert runned == expected


@mock.patch("graphrag.index.operations.chunk_text.chunk_text._load_strategy")
@mock.patch("graphrag.index.operations.chunk_text.chunk_text._run_strategy")
@mock.patch("graphrag.index.operations.chunk_text.chunk_text.progress_ticker")
def test_chunk_text(mock_progress_ticker, mock_run_strategy, mock_load_strategy):
    input_data = pd.DataFrame({"name": ["The Shining"]})
    column = "name"
    size = 10
    overlap = 2
    encoding_model = "model"
    strategy = ChunkStrategyType.sentence
    callbacks = Mock()
    callbacks.progress = Mock()

    mock_load_strategy.return_value = Mock()
    mock_progress_ticker.return_value = Mock()

    chunk_text(input_data, column, size, overlap, encoding_model, strategy, callbacks)

    mock_run_strategy.assert_called_with(
        mock_load_strategy(), "The Shining", ANY, mock_progress_ticker.return_value
    )
@@ -1,127 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from unittest.mock import Mock, patch

from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.bootstrap import bootstrap
from graphrag.index.operations.chunk_text.strategies import (
    run_sentences,
    run_tokens,
)
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.tokenizer.get_tokenizer import get_tokenizer


class TestRunSentences:
    def setup_method(self, method):
        bootstrap()

    def test_basic_functionality(self):
        """Test basic sentence splitting without metadata"""
        input = ["This is a test. Another sentence."]
        tick = Mock()
        chunks = list(run_sentences(input, ChunkingConfig(), tick))

        assert len(chunks) == 2
        assert chunks[0].text_chunk == "This is a test."
        assert chunks[1].text_chunk == "Another sentence."
        assert all(c.source_doc_indices == [0] for c in chunks)
        tick.assert_called_once_with(1)

    def test_multiple_documents(self):
        """Test processing multiple input documents"""
        input = ["First. Document.", "Second. Doc."]
        tick = Mock()
        chunks = list(run_sentences(input, ChunkingConfig(), tick))

        assert len(chunks) == 4
        assert chunks[0].source_doc_indices == [0]
        assert chunks[2].source_doc_indices == [1]
        assert tick.call_count == 2

    def test_mixed_whitespace_handling(self):
        """Test input with irregular whitespace"""
        input = [" Sentence with spaces. Another one! "]
        chunks = list(run_sentences(input, ChunkingConfig(), Mock()))
        assert chunks[0].text_chunk == " Sentence with spaces."
        assert chunks[1].text_chunk == "Another one!"


class TestRunTokens:
    @patch("tiktoken.get_encoding")
    def test_basic_functionality(self, mock_get_encoding):
        mock_encoder = Mock()
        mock_encoder.encode.side_effect = lambda x: list(x.encode())
        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
        mock_get_encoding.return_value = mock_encoder

        # Input and config
        input = [
            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
        ]
        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
        tick = Mock()

        # Run the function
        chunks = list(run_tokens(input, config, tick))

        # Verify output
        assert len(chunks) > 0
        assert all(isinstance(chunk, TextChunk) for chunk in chunks)
        tick.assert_called_once_with(1)

    @patch("tiktoken.get_encoding")
    def test_non_string_input(self, mock_get_encoding):
        """Test handling of non-string input (e.g., numbers)."""
        mock_encoder = Mock()
        mock_encoder.encode.side_effect = lambda x: list(str(x).encode())
        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
        mock_get_encoding.return_value = mock_encoder

        input = [123]  # Non-string input
        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
        tick = Mock()

        chunks = list(run_tokens(input, config, tick))  # type: ignore

        # Verify non-string input is handled
        assert len(chunks) > 0
        assert "123" in chunks[0].text_chunk


@patch("tiktoken.get_encoding")
def test_get_encoding_fn_encode(mock_get_encoding):
    # Create a mock encoding object with encode and decode methods
    mock_encoding = Mock()
    mock_encoding.encode = Mock(return_value=[1, 2, 3])
    mock_encoding.decode = Mock(return_value="decoded text")

    # Configure the mock_get_encoding to return the mock encoding object
    mock_get_encoding.return_value = mock_encoding

    # Call the function to get encode and decode functions
    tokenizer = get_tokenizer(encoding_model="mock_encoding")

    # Test the encode function
    encoded_text = tokenizer.encode("test text")
    assert encoded_text == [1, 2, 3]
    mock_encoding.encode.assert_called_once_with("test text")


@patch("tiktoken.get_encoding")
def test_get_encoding_fn_decode(mock_get_encoding):
    # Create a mock encoding object with encode and decode methods
    mock_encoding = Mock()
    mock_encoding.encode = Mock(return_value=[1, 2, 3])
    mock_encoding.decode = Mock(return_value="decoded text")

    # Configure the mock_get_encoding to return the mock encoding object
    mock_get_encoding.return_value = mock_encoding

    # Call the function to get encode and decode functions
    tokenizer = get_tokenizer(encoding_model="mock_encoding")

    decoded_text = tokenizer.decode([1, 2, 3])
    assert decoded_text == "decoded text"
    mock_encoding.decode.assert_called_once_with([1, 2, 3])
@@ -1,2 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
@@ -1,169 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from unittest import mock
from unittest.mock import MagicMock

import pytest
import tiktoken
from graphrag.index.text_splitting.text_splitting import (
    NoopTextSplitter,
    TokenChunkerOptions,
    TokenTextSplitter,
    split_multiple_texts_on_tokens,
    split_single_text_on_tokens,
)


def test_noop_text_splitter() -> None:
    splitter = NoopTextSplitter()

    assert list(splitter.split_text("some text")) == ["some text"]
    assert list(splitter.split_text(["some", "text"])) == ["some", "text"]


class MockTokenizer:
    def encode(self, text):
        return [ord(char) for char in text]

    def decode(self, token_ids):
        return "".join(chr(id) for id in token_ids)


def test_split_text_str_empty():
    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
    result = splitter.split_text("")

    assert result == []


def test_split_text_str_bool():
    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
    result = splitter.split_text(None)  # type: ignore

    assert result == []


def test_split_text_str_int():
    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
    with pytest.raises(TypeError):
        splitter.split_text(123)  # type: ignore


@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
def test_split_text_large_input(mock_split):
    large_text = "a" * 10_000
    mock_split.return_value = ["chunk"] * 2_000
    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)

    result = splitter.split_text(large_text)

    assert len(result) == 2_000, "Large input was not split correctly"
    mock_split.assert_called_once()


@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
@mock.patch("graphrag.index.text_splitting.text_splitting.TokenChunkerOptions")
def test_token_text_splitter(mock_tokenizer, mock_split_text):
    text = "chunk1 chunk2 chunk3"
    expected_chunks = ["chunk1", "chunk2", "chunk3"]

    mocked_tokenizer = MagicMock()
    mock_tokenizer.return_value = mocked_tokenizer
    mock_split_text.return_value = expected_chunks

    splitter = TokenTextSplitter()

    splitter.split_text(["chunk1", "chunk2", "chunk3"])

    mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)


def test_split_single_text_on_tokens():
    text = "This is a test text, meaning to be taken seriously by this test only."
    mocked_tokenizer = MockTokenizer()
    tokenizer = TokenChunkerOptions(
        chunk_overlap=5,
        tokens_per_chunk=10,
        decode=mocked_tokenizer.decode,
        encode=lambda text: mocked_tokenizer.encode(text),
    )

    expected_splits = [
        "This is a ",
        "is a test ",
        "test text,",
        "text, mean",
        " meaning t",
        "ing to be ",
        "o be taken",
        "taken seri",  # cspell:disable-line
        " seriously",
        "ously by t",  # cspell:disable-line
        " by this t",
        "his test o",
        "est only.",
    ]

    result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
    assert result == expected_splits


def test_split_multiple_texts_on_tokens():
    texts = [
        "This is a test text, meaning to be taken seriously by this test only.",
        "This is th second text, meaning to be taken seriously by this test only.",
    ]

    mocked_tokenizer = MockTokenizer()
    mock_tick = MagicMock()
    tokenizer = TokenChunkerOptions(
        chunk_overlap=5,
        tokens_per_chunk=10,
        decode=mocked_tokenizer.decode,
        encode=lambda text: mocked_tokenizer.encode(text),
    )

    split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick)
    mock_tick.assert_called()


def test_split_single_text_on_tokens_no_overlap():
    text = "This is a test text, meaning to be taken seriously by this test only."
    enc = tiktoken.get_encoding("cl100k_base")

    def encode(text: str) -> list[int]:
        if not isinstance(text, str):
            text = f"{text}"
        return enc.encode(text)

    def decode(tokens: list[int]) -> str:
        return enc.decode(tokens)

    tokenizer = TokenChunkerOptions(
        chunk_overlap=1,
        tokens_per_chunk=2,
        decode=decode,
        encode=lambda text: encode(text),
    )

    expected_splits = [
        "This is",
        " is a",
        " a test",
        " test text",
        " text,",
        ", meaning",
        " meaning to",
        " to be",
        " be taken",  # cspell:disable-line
        " taken seriously",  # cspell:disable-line
        " seriously by",
        " by this",  # cspell:disable-line
        " this test",
        " test only",
        " only.",
    ]

    result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
    assert result == expected_splits
@@ -35,7 +35,7 @@ async def test_create_base_text_units_metadata():

    config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
    config.input.metadata = ["title"]
-   config.chunks.prepend_metadata = True
+   config.chunking.prepend_metadata = True

    await update_document_metadata(config.input.metadata, context)

@@ -43,22 +43,3 @@ async def test_create_base_text_units_metadata():

    actual = await load_table_from_storage("text_units", context.output_storage)
    compare_outputs(actual, expected, ["text", "document_id", "n_tokens"])
-
-
-async def test_create_base_text_units_metadata_included_in_chunk():
-   expected = load_test_table("text_units_metadata_included_chunk")
-
-   context = await create_test_context()
-
-   config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG)  # type: ignore
-   config.input.metadata = ["title"]
-   config.chunks.prepend_metadata = True
-   config.chunks.chunk_size_includes_metadata = True
-
-   await update_document_metadata(config.input.metadata, context)
-
-   await run_workflow(config, context)
-
-   actual = await load_table_from_storage("text_units", context.output_storage)
-   # only check the columns from the base workflow - our expected table is the final and will have more
-   compare_outputs(actual, expected, columns=["text", "document_id", "n_tokens"])