Merge branch 'v3/main' into input-factory

This commit is contained in:
Nathan Evans 2026-01-06 15:42:16 -08:00
commit 8e3c7170f7
44 changed files with 1985 additions and 2296 deletions

View File

@ -100,12 +100,11 @@ These settings configure how we parse documents into text chunks. This is necess
#### Fields
- `strategy` **str**[tokens|sentences] - How to chunk the text.
- `size` **int** - The max chunk size in tokens.
- `overlap` **int** - The chunk overlap in tokens.
- `encoding_model` **str** - The text encoding model to use for splitting on token boundaries.
- `prepend_metadata` **bool** - Determines if metadata values should be added at the beginning of each chunk. Default=`False`.
- `chunk_size_includes_metadata` **bool** - Specifies whether the chunk size calculation should include metadata tokens. Default=`False`.
## Outputs and Storage

View File

@ -82,10 +82,9 @@ As described above, when documents are imported you can specify a list of `metad
### Chunking Config
Next, the `chunks` block needs to instruct the chunker how to handle this metadata when creating text units. By default, it is ignored. We have two settings to include it:
Next, the `chunks` block needs to instruct the chunker how to handle this metadata when creating text units. By default, it is ignored. We have the following setting to include it:
- `prepend_metadata`. This instructs the importer to copy the contents of the `metadata` column for each row into the start of every single text chunk. This metadata is copied as key: value pairs on new lines.
- `chunk_size_includes_metadata`: This tells the chunker how to compute the chunk size when metadata is included. By default, we create the text units using your specified `chunk_size` *and then* prepend the metadata. This means that the final text unit lengths may be longer than your configured `chunk_size`, and it will vary based on the length of the metadata for each document. When this setting is `True`, we will compute the raw text using the remainder after measuring the metadata length so that the resulting text units always comply with your configured `chunk_size`.
### Examples
@ -124,7 +123,6 @@ chunks:
size: 100
overlap: 0
prepend_metadata: true
chunk_size_includes_metadata: false
```
Documents DataFrame
@ -162,54 +160,6 @@ US to lift most federal COVID-19 vaccine mandates,WASHINGTON (AP) The Biden admi
NY lawmakers begin debating budget 1 month after due date,ALBANY, N.Y. (AP) New York lawmakers began voting Monday on a $229 billion state budget due a month ago that would raise the minimum wage, crack down on illicit pot shops and ban gas stoves and furnaces in new buildings. Negotiations among Gov. Kathy Hochul and her fellow Democrats in control of the Legislature dragged on past the April 1 budget deadline, largely because of disagreements over changes to the bail law and other policy proposals included in the spending plan. Floor debates on some budget bills began Monday. State Senate Majority Leader Andrea Stewart-Cousins said she expected voting to be wrapped up Tuesday for a budget she said contains "significant wins" for New Yorkers. "I would have liked to have done this sooner. I think we would all agree to that," Cousins told reporters before voting began. "This has been a very policy-laden budget and a lot of the policies had to parsed through." Hochul was able to push through a change to the bail law that will eliminate the standard that requires judges to prescribe the "least restrictive" means to ensure defendants return to court. Hochul said judges needed the extra discretion. Some liberal lawmakers argued that it would undercut the sweeping bail reforms approved in 2019 and result in more people with low incomes and people of color in pretrial detention. Here are some other policy provisions that will be included in the budget, according to state officials. The minimum wage would be raised to $17 in New York City and some of its suburbs and $16 in the rest of the state by 2026. That's up from $15 in the city and $14.20 upstate.
--
settings.yaml
```yaml
input:
file_type: csv
title_column: headline
text_column: article
metadata: [headline]
chunks:
size: 50
overlap: 5
prepend_metadata: true
chunk_size_includes_metadata: true
```
Documents DataFrame
| id | title | text | creation_date | metadata |
| --------------------- | --------------------------------------------------------- | ------------------------ | ----------------------------- | --------------------------------------------------------------------------- |
| (generated from text) | US to lift most federal COVID-19 vaccine mandates | (article column content) | (create date of articles.csv) | { "headline": "US to lift most federal COVID-19 vaccine mandates" } |
| (generated from text) | NY lawmakers begin debating budget 1 month after due date | (article column content) | (create date of articles.csv) | { "headline": "NY lawmakers begin debating budget 1 month after due date" } |
Raw Text Chunks
| content | length |
| ------- | ------: |
| title: US to lift most federal COVID-19 vaccine mandates<br>WASHINGTON (AP) The Biden administration will end most of the last remaining federal COVID-19 vaccine requirements next week when the national public health emergency for the coronavirus ends, the White House said Monday. Vaccine requirements for federal workers and federal contractors, | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>federal workers and federal contractors as well as foreign air travelers to the U.S., will end May 11. The government is also beginning the process of lifting shot requirements for Head Start educators, healthcare workers, and noncitizens at U.S. land borders. | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>noncitizens at U.S. land borders. The requirements are among the last vestiges of some of the more coercive measures taken by the federal government to promote vaccination as the deadly virus raged, and their end marks the latest display of how | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>the latest display of how President Joe Biden's administration is moving to treat COVID-19 as a routine, endemic illness. "While I believe that these vaccine mandates had a tremendous beneficial impact, we are now at a point where we think that | 50 |
| title: US to lift most federal COVID-19 vaccine mandates<br>point where we think that it makes a lot of sense to pull these requirements down," White House COVID-19 coordinator Dr. Ashish Jha told The Associated Press on Monday. | 38 |
| title: NY lawmakers begin debating budget 1 month after due date<br>ALBANY, N.Y. (AP) New York lawmakers began voting Monday on a $229 billion state budget due a month ago that would raise the minimum wage, crack down on illicit pot shops and ban gas stoves and furnaces in new | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>stoves and furnaces in new buildings. Negotiations among Gov. Kathy Hochul and her fellow Democrats in control of the Legislature dragged on past the April 1 budget deadline, largely because of disagreements over changes to the bail law and | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>to the bail law and other policy proposals included in the spending plan. Floor debates on some budget bills began Monday. State Senate Majority Leader Andrea Stewart-Cousins said she expected voting to be wrapped up Tuesday for a budget | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>up Tuesday for a budget she said contains "significant wins" for New Yorkers. "I would have liked to have done this sooner. I think we would all agree to that," Cousins told reporters before voting began. "This has been | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>voting began. "This has been a very policy-laden budget and a lot of the policies had to parsed through." Hochul was able to push through a change to the bail law that will eliminate the standard that requires judges | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>the standard that requires judges to prescribe the "least restrictive" means to ensure defendants return to court. Hochul said judges needed the extra discretion. Some liberal lawmakers argued that it would undercut the sweeping bail reforms approved in 2019 | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>bail reforms approved in 2019 and result in more people with low incomes and people of color in pretrial detention. Here are some other policy provisions that will be included in the budget, according to state officials. The minimum | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>to state officials. The minimum wage would be raised to $17 in New York City and some of its suburbs and $16 in the rest of the state by 2026. That's up from $15 | 50 |
| title: NY lawmakers begin debating budget 1 month after due date<br>2026. That's up from $15 in the city and $14.20 upstate. | 22 |
In this example we can see that the two input documents were parsed into fourteen output text chunks. The title (headline) of each document is prepended and included in the computed chunk size, so each chunk matches the configured chunk size (except the last one for each document). We've also configured some overlap in these text chunks, so the last five tokens are shared. Why would you use overlap in your text chunks? Consider that when you are splitting documents based on tokens, it is highly likely that sentences or even related concepts will be split into separate chunks. Each text chunk is processed separately by the language model, so this may result in incomplete "ideas" at the boundaries of the chunk. Overlap ensures that these split concepts are fully contained in at least one of the chunks.
#### JSON files
This final example uses a JSON file for each of the same two articles. In this example we'll set the object fields to read, but we will not add metadata to the text chunks.
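A minimal configuration for that setup might look like the following sketch, assuming the JSON reader uses the same `title_column`/`text_column` settings as the CSV reader and the same field names as the examples above:
```yaml
input:
  file_type: json
  title_column: headline
  text_column: article
chunks:
  size: 50
  overlap: 5
```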

View File

@ -0,0 +1,32 @@
# GraphRAG Chunking
This package contains a collection of text chunkers, a core config model, and a factory for acquiring instances.
## Examples
Basic sentence chunking with nltk
```python
from graphrag_chunking.sentence_chunker import SentenceChunker

chunker = SentenceChunker()
chunks = chunker.chunk("This is a test. Another sentence.")
print([c.text for c in chunks])  # ["This is a test.", "Another sentence."]
```
Token chunking
```python
import tiktoken

from graphrag_chunking.token_chunker import TokenChunker

tokenizer = tiktoken.get_encoding("o200k_base")
chunker = TokenChunker(size=3, overlap=0, encode=tokenizer.encode, decode=tokenizer.decode)
chunks = chunker.chunk("This is a random test fragment of some text")
print([c.text for c in chunks])  # ["This is a", " random test fragment", " of some text"]
```
Using the factory via helper util
```python
import tiktoken

from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig

tokenizer = tiktoken.get_encoding("o200k_base")
config = ChunkingConfig(
    type="tokens",
    size=3,
    overlap=0,
)
chunker = create_chunker(config, tokenizer.encode, tokenizer.decode)
...
```

View File

@ -1,2 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""System-level chunking package."""

View File

@ -0,0 +1,19 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'prepend_metadata' function."""
def add_metadata(
text: str,
metadata: dict,
delimiter: str = ": ",
line_delimiter: str = "\n",
append: bool = False,
) -> str:
"""Add metadata to the given text, prepending by default. This utility writes the dict as rows of key/value pairs."""
metadata_str = (
line_delimiter.join(f"{k}{delimiter}{v}" for k, v in metadata.items())
+ line_delimiter
)
return text + metadata_str if append else metadata_str + text
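
For illustration, a quick usage sketch (behavior matches the tests later in this diff):
```python
from graphrag_chunking.add_metadata import add_metadata

add_metadata("This is a test.", {"message": "hello"})
# -> "message: hello\nThis is a test."
add_metadata("This is a test.", {"message": "hello"}, append=True)
# -> "This is a test.message: hello\n"
```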

View File

@ -0,0 +1,17 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The ChunkResult dataclass."""
from dataclasses import dataclass


@dataclass
class ChunkResult:
    """Result of chunking a document."""

    text: str
    index: int
    start_char: int
    end_char: int
    token_count: int | None = None

View File

@ -0,0 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Chunk strategy type enumeration."""
from enum import StrEnum


class ChunkerType(StrEnum):
    """ChunkerType class definition."""

    Tokens = "tokens"
    Sentence = "sentence"

View File

@ -0,0 +1,21 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the 'Chunker' class."""
from abc import ABC, abstractmethod
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult


class Chunker(ABC):
    """Abstract base class for document chunkers."""

    @abstractmethod
    def __init__(self, **kwargs: Any) -> None:
        """Create a chunker instance."""

    @abstractmethod
    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk method definition."""

View File

@ -0,0 +1,77 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'ChunkerFactory', 'register_chunker', and 'create_chunker'."""
from collections.abc import Callable

from graphrag_common.factory.factory import Factory, ServiceScope

from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunking_config import ChunkingConfig


class ChunkerFactory(Factory[Chunker]):
    """Factory for creating Chunker instances."""


chunker_factory = ChunkerFactory()


def register_chunker(
    chunker_type: str,
    chunker_initializer: Callable[..., Chunker],
    scope: ServiceScope = "transient",
) -> None:
    """Register a custom chunker implementation.

    Args
    ----
    - chunker_type: str
        The chunker id to register.
    - chunker_initializer: Callable[..., Chunker]
        The chunker initializer to register.
    """
    chunker_factory.register(chunker_type, chunker_initializer, scope)


def create_chunker(
    config: ChunkingConfig,
    encode: Callable[[str], list[int]] | None = None,
    decode: Callable[[list[int]], str] | None = None,
) -> Chunker:
    """Create a chunker implementation based on the given configuration.

    Args
    ----
    - config: ChunkingConfig
        The chunker configuration to use.

    Returns
    -------
    Chunker
        The created chunker implementation.
    """
    config_model = config.model_dump()
    if encode is not None:
        config_model["encode"] = encode
    if decode is not None:
        config_model["decode"] = decode
    chunker_strategy = config.type
    if chunker_strategy not in chunker_factory:
        match chunker_strategy:
            case ChunkerType.Tokens:
                from graphrag_chunking.token_chunker import TokenChunker

                register_chunker(ChunkerType.Tokens, TokenChunker)
            case ChunkerType.Sentence:
                from graphrag_chunking.sentence_chunker import SentenceChunker

                register_chunker(ChunkerType.Sentence, SentenceChunker)
            case _:
                msg = f"ChunkingConfig.type '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
                raise ValueError(msg)
    return chunker_factory.create(chunker_strategy, init_args=config_model)
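
For illustration, a sketch of the extension point this factory provides; the `LineChunker` name and its newline-splitting rule are hypothetical:
```python
from typing import Any

from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunker_factory import create_chunker, register_chunker
from graphrag_chunking.chunking_config import ChunkingConfig


class LineChunker(Chunker):
    """Hypothetical chunker that emits one chunk per line."""

    def __init__(self, **kwargs: Any) -> None:
        """Ignore the unused config fields passed in by the factory."""

    def chunk(self, text: str) -> list[ChunkResult]:
        results = []
        start = 0
        for index, line in enumerate(text.splitlines()):
            results.append(
                ChunkResult(text=line, index=index, start_char=start, end_char=start + len(line) - 1)
            )
            start += len(line) + 1  # account for the newline
        return results


register_chunker("lines", LineChunker)
chunker = create_chunker(ChunkingConfig(type="lines"))
```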

View File

@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, ConfigDict, Field
from graphrag_chunking.chunk_strategy_type import ChunkerType
class ChunkingConfig(BaseModel):
    """Configuration section for chunking."""

    model_config = ConfigDict(extra="allow")
    """Allow extra fields to support custom chunker implementations."""

    type: str = Field(
        description="The chunking type to use.",
        default=ChunkerType.Tokens,
    )
    encoding_model: str | None = Field(
        description="The encoding model to use.",
        default=None,
    )
    size: int = Field(
        description="The chunk size to use.",
        default=1200,
    )
    overlap: int = Field(
        description="The chunk overlap to use.",
        default=100,
    )
    prepend_metadata: bool = Field(
        description="Prepend metadata into each chunk.",
        default=False,
    )
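
Because `extra="allow"` is set, unrecognized fields flow through `model_dump()` into a custom chunker's initializer. A small sketch; the `min_sentence_len` field is hypothetical:
```python
from graphrag_chunking.chunking_config import ChunkingConfig

config = ChunkingConfig(type="sentence", min_sentence_len=10)
print(config.model_dump()["min_sentence_len"])  # 10
```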

View File

@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'create_chunk_results' function."""
from collections.abc import Callable
from graphrag_chunking.chunk_result import ChunkResult
def create_chunk_results(
    chunks: list[str],
    encode: Callable[[str], list[int]] | None = None,
) -> list[ChunkResult]:
    """Create chunk results from a list of text chunks. The index assignments are 0-based and assume chunks were not stripped relative to the source text."""
    results = []
    start_char = 0
    for index, chunk in enumerate(chunks):
        end_char = start_char + len(chunk) - 1  # 0-based, inclusive indices
        result = ChunkResult(
            text=chunk,
            index=index,
            start_char=start_char,
            end_char=end_char,
        )
        if encode:
            result.token_count = len(encode(result.text))
        results.append(result)
        start_char = end_char + 1
    return results
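
For illustration, the offset bookkeeping for two chunks that tile the source text exactly:
```python
from graphrag_chunking.create_chunk_results import create_chunk_results

results = create_chunk_results(["ab", "cde"])
# results[0]: index=0, start_char=0, end_char=1
# results[1]: index=1, start_char=2, end_char=4
```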

View File

@ -0,0 +1,44 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'SentenceChunker' class."""
from collections.abc import Callable
from typing import Any
import nltk
from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
class SentenceChunker(Chunker):
    """A chunker that splits text into sentence-based chunks."""

    def __init__(
        self, encode: Callable[[str], list[int]] | None = None, **kwargs: Any
    ) -> None:
        """Create a sentence chunker instance."""
        self._encode = encode
        bootstrap()

    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into sentence-based chunks."""
        sentences = nltk.sent_tokenize(text.strip())
        results = create_chunk_results(sentences, encode=self._encode)
        # nltk's sentence tokenizer may trim whitespace, so adjust start/end chars
        # to point back into the original (untrimmed) text
        for index, result in enumerate(results):
            txt = result.text
            start = result.start_char
            actual_start = text.find(txt, start)
            delta = actual_start - start
            if delta > 0:
                result.start_char += delta
                result.end_char += delta
                # shift the next result too, so its estimated start position
                # does not fall behind when we search for it in the source text
                if index < len(results) - 1:
                    results[index + 1].start_char += delta
                    results[index + 1].end_char += delta
        return results
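
A quick illustration of that offset correction (nltk data is bootstrapped by the constructor); with three leading spaces, the first sentence's start_char lands at 3 in the original string:
```python
from graphrag_chunking.sentence_chunker import SentenceChunker

chunker = SentenceChunker()
results = chunker.chunk("   Sentence with spaces. Another one!")
print(results[0].text, results[0].start_char)  # Sentence with spaces. 3
```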

View File

@ -0,0 +1,67 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'TokenChunker' class."""
from collections.abc import Callable
from typing import Any
from graphrag_chunking.chunk_result import ChunkResult
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.create_chunk_results import create_chunk_results
class TokenChunker(Chunker):
    """A chunker that splits text into token-based chunks."""

    def __init__(
        self,
        size: int,
        overlap: int,
        encode: Callable[[str], list[int]],
        decode: Callable[[list[int]], str],
        **kwargs: Any,
    ) -> None:
        """Create a token chunker instance."""
        self._size = size
        self._overlap = overlap
        self._encode = encode
        self._decode = decode

    def chunk(self, text: str) -> list[ChunkResult]:
        """Chunk the text into token-based chunks."""
        chunks = split_text_on_tokens(
            text,
            chunk_size=self._size,
            chunk_overlap=self._overlap,
            encode=self._encode,
            decode=self._decode,
        )
        return create_chunk_results(chunks, encode=self._encode)


def split_text_on_tokens(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    encode: Callable[[str], list[int]],
    decode: Callable[[list[int]], str],
) -> list[str]:
    """Split a single text and return chunks using the tokenizer."""
    result = []
    input_tokens = encode(text)
    start_idx = 0
    cur_idx = min(start_idx + chunk_size, len(input_tokens))
    chunk_tokens = input_tokens[start_idx:cur_idx]
    while start_idx < len(input_tokens):
        chunk_text = decode(list(chunk_tokens))
        result.append(chunk_text)  # append chunked text as string
        if cur_idx == len(input_tokens):
            break
        # advance the window by (chunk_size - chunk_overlap) tokens
        start_idx += chunk_size - chunk_overlap
        cur_idx = min(start_idx + chunk_size, len(input_tokens))
        chunk_tokens = input_tokens[start_idx:cur_idx]
    return result
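
To make the window arithmetic concrete, a small sketch using a character-level encode/decode pair (the same trick as the MockTokenizer in the tests later in this diff): with `chunk_size=3` and `chunk_overlap=1`, the window advances two tokens per step.
```python
from graphrag_chunking.token_chunker import split_text_on_tokens

encode = lambda s: [ord(c) for c in s]  # one "token" per character
decode = lambda ids: "".join(chr(i) for i in ids)

chunks = split_text_on_tokens("abcdefg", chunk_size=3, chunk_overlap=1, encode=encode, decode=decode)
print(chunks)  # ['abc', 'cde', 'efg']
```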

View File

@ -0,0 +1,43 @@
[project]
name = "graphrag-chunking"
version = "2.7.0"
description = "Chunking utilities for GraphRAG"
authors = [
{name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"},
{name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"},
{name = "Chris Trevino", email = "chtrevin@microsoft.com"},
{name = "David Tittsworth", email = "datittsw@microsoft.com"},
{name = "Dayenne de Souza", email = "ddesouza@microsoft.com"},
{name = "Derek Worthen", email = "deworthe@microsoft.com"},
{name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"},
{name = "Ha Trinh", email = "trinhha@microsoft.com"},
{name = "Jonathan Larson", email = "jolarso@microsoft.com"},
{name = "Josh Bradley", email = "joshbradley@microsoft.com"},
{name = "Kate Lytvynets", email = "kalytv@microsoft.com"},
{name = "Kenny Zhang", email = "zhangken@microsoft.com"},
{name = "Mónica Carvajal"},
{name = "Nathan Evans", email = "naevans@microsoft.com"},
{name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"},
{name = "Sarah Smith", email = "smithsarah@microsoft.com"},
]
license = {text = "MIT"}
readme = "README.md"
requires-python = ">=3.11,<3.14"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = [
"graphrag-common==2.7.0",
"pydantic~=2.10",
]
[project.urls]
Source = "https://github.com/microsoft/graphrag"
[build-system]
requires = ["hatchling>=1.27.0,<2.0.0"]
build-backend = "hatchling.build"

View File

@ -12,13 +12,10 @@ Backwards compatibility is not guaranteed at this time.
"""
import logging
from typing import Annotated
import annotated_types
from pydantic import PositiveInt, validate_call
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.language_model.manager import ModelManager
from graphrag.logger.standard_logging import init_loggers
@ -55,10 +52,6 @@ logger = logging.getLogger(__name__)
@validate_call(config={"arbitrary_types_allowed": True})
async def generate_indexing_prompts(
config: GraphRagConfig,
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
overlap: Annotated[
int, annotated_types.Gt(-1)
] = graphrag_config_defaults.chunks.overlap,
limit: PositiveInt = 15,
selection_method: DocSelectionType = DocSelectionType.RANDOM,
domain: str | None = None,
@ -100,8 +93,6 @@ async def generate_indexing_prompts(
limit=limit,
select_method=selection_method,
logger=logger,
chunk_size=chunk_size,
overlap=overlap,
n_subset_max=n_subset_max,
k=k,
)

View File

@ -306,14 +306,14 @@ def _prompt_tune_cli(
help="The minimum number of examples to generate/include in the entity extraction prompt.",
),
chunk_size: int = typer.Option(
graphrag_config_defaults.chunks.size,
graphrag_config_defaults.chunking.size,
"--chunk-size",
help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
help="The size of each example text chunk. Overrides chunking.size in the configuration file.",
),
overlap: int = typer.Option(
graphrag_config_defaults.chunks.overlap,
graphrag_config_defaults.chunking.overlap,
"--overlap",
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
help="The overlap size for chunking documents. Overrides chunking.overlap in the configuration file.",
),
language: str | None = typer.Option(
None,

View File

@ -61,11 +61,11 @@ async def prompt_tune(
)
# override chunking config in the configuration
if chunk_size != graph_config.chunks.size:
graph_config.chunks.size = chunk_size
if chunk_size != graph_config.chunking.size:
graph_config.chunking.size = chunk_size
if overlap != graph_config.chunks.overlap:
graph_config.chunks.overlap = overlap
if overlap != graph_config.chunking.overlap:
graph_config.chunking.overlap = overlap
# configure the root logger with the specified log level
from graphrag.logger.standard_logging import init_loggers
@ -81,8 +81,6 @@ async def prompt_tune(
prompts = await api.generate_indexing_prompts(
config=graph_config,
chunk_size=chunk_size,
overlap=overlap,
limit=limit,
selection_method=selection_method,
domain=domain,

View File

@ -8,13 +8,13 @@ from pathlib import Path
from typing import ClassVar
from graphrag_cache import CacheType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_storage import StorageType
from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
AsyncType,
AuthType,
ChunkStrategyType,
ModelType,
NounPhraseExtractorType,
ReportingType,
@ -57,15 +57,14 @@ class BasicSearchDefaults:
@dataclass
class ChunksDefaults:
"""Default values for chunks."""
class ChunkingDefaults:
"""Default values for chunking."""
type: str = ChunkerType.Tokens
size: int = 1200
overlap: int = 100
strategy: ClassVar[ChunkStrategyType] = ChunkStrategyType.tokens
encoding_model: str = ENCODING_MODEL
prepend_metadata: bool = False
chunk_size_includes_metadata: bool = False
@dataclass
@ -127,7 +126,6 @@ class EmbedTextDefaults:
batch_size: int = 16
batch_max_tokens: int = 8191
names: list[str] = field(default_factory=lambda: default_embeddings)
strategy: None = None
@dataclass
@ -140,7 +138,6 @@ class ExtractClaimsDefaults:
"Any claims or facts that could be relevant to information discovery."
)
max_gleanings: int = 1
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "extract_claims"
@ -154,7 +151,6 @@ class ExtractGraphDefaults:
default_factory=lambda: ["organization", "person", "geo", "event"]
)
max_gleanings: int = 1
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "extract_graph"
@ -362,7 +358,6 @@ class SummarizeDescriptionsDefaults:
prompt: None = None
max_length: int = 500
max_input_tokens: int = 4_000
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "summarize_descriptions"
@ -403,7 +398,7 @@ class GraphRagConfigDefaults:
cache: CacheDefaults = field(default_factory=CacheDefaults)
input: InputDefaults = field(default_factory=InputDefaults)
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
extract_graph_nlp: ExtractGraphNLPDefaults = field(

View File

@ -80,17 +80,6 @@ class AsyncType(str, Enum):
Threaded = "threaded"
class ChunkStrategyType(str, Enum):
"""ChunkStrategy class definition."""
tokens = "tokens"
sentence = "sentence"
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class SearchMethod(Enum):
"""The type of search to run."""

View File

@ -54,9 +54,11 @@ input:
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
chunks:
size: {graphrag_config_defaults.chunks.size}
overlap: {graphrag_config_defaults.chunks.overlap}
chunking:
type: {graphrag_config_defaults.chunking.type}
size: {graphrag_config_defaults.chunking.size}
overlap: {graphrag_config_defaults.chunking.overlap}
encoding_model: {graphrag_config_defaults.chunking.encoding_model}
### Output/storage settings ###
## If blob storage is specified in the following four sections,

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import ChunkStrategyType
class ChunkingConfig(BaseModel):
    """Configuration section for chunking."""

    size: int = Field(
        description="The chunk size to use.",
        default=graphrag_config_defaults.chunks.size,
    )
    overlap: int = Field(
        description="The chunk overlap to use.",
        default=graphrag_config_defaults.chunks.overlap,
    )
    strategy: ChunkStrategyType = Field(
        description="The chunking strategy to use.",
        default=graphrag_config_defaults.chunks.strategy,
    )
    encoding_model: str = Field(
        description="The encoding model to use.",
        default=graphrag_config_defaults.chunks.encoding_model,
    )
    prepend_metadata: bool = Field(
        description="Prepend metadata into each chunk.",
        default=graphrag_config_defaults.chunks.prepend_metadata,
    )
    chunk_size_includes_metadata: bool = Field(
        description="Count metadata in max tokens.",
        default=graphrag_config_defaults.chunks.chunk_size_includes_metadata,
    )

View File

@ -8,6 +8,7 @@ from pathlib import Path
from devtools import pformat
from graphrag_cache import CacheConfig
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_storage import StorageConfig, StorageType
from pydantic import BaseModel, Field, model_validator
@ -15,7 +16,6 @@ import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@ -117,9 +117,15 @@ class GraphRagConfig(BaseModel):
Path(self.input.storage.base_dir).resolve()
)
chunks: ChunkingConfig = Field(
chunking: ChunkingConfig = Field(
description="The chunking configuration to use.",
default=ChunkingConfig(),
default=ChunkingConfig(
type=graphrag_config_defaults.chunking.type,
size=graphrag_config_defaults.chunking.size,
overlap=graphrag_config_defaults.chunking.overlap,
encoding_model=graphrag_config_defaults.chunking.encoding_model,
prepend_metadata=graphrag_config_defaults.chunking.prepend_metadata,
),
)
"""The chunking configuration to use."""

View File

@ -1,4 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The Indexing Engine text chunk package root."""

View File

@ -1,140 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing chunk_text method definitions."""
from typing import Any, cast
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.index.operations.chunk_text.typing import (
    ChunkInput,
    ChunkStrategy,
)
from graphrag.logger.progress import ProgressTicker, progress_ticker


def chunk_text(
    input: pd.DataFrame,
    column: str,
    size: int,
    overlap: int,
    encoding_model: str,
    strategy: ChunkStrategyType,
    callbacks: WorkflowCallbacks,
) -> pd.Series:
    """
    Chunk a piece of text into smaller pieces.

    ## Usage
    ```yaml
    args:
        column: <column name> # The name of the column containing the text to chunk, this can either be a column with text, or a column with a list[tuple[doc_id, str]]
        strategy: <strategy config> # The strategy to use to chunk the text, see below for more details
    ```

    ## Strategies
    The text chunk verb uses a strategy to chunk the text. The strategy is an object which defines the strategy to use. The following strategies are available:

    ### tokens
    This strategy uses the [tokens] library to chunk a piece of text. The strategy config is as follows:
    ```yaml
    strategy: tokens
    size: 1200 # Optional, The chunk size to use, default: 1200
    overlap: 100 # Optional, The chunk overlap to use, default: 100
    ```

    ### sentence
    This strategy uses the nltk library to chunk a piece of text into sentences. The strategy config is as follows:
    ```yaml
    strategy: sentence
    ```
    """
    strategy_exec = _load_strategy(strategy)
    num_total = _get_num_total(input, column)
    tick = progress_ticker(callbacks.progress, num_total)
    # collapse the config back to a single object to support "polymorphic" function call
    config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model)
    return cast(
        "pd.Series",
        input.apply(
            cast(
                "Any",
                lambda x: _run_strategy(
                    strategy_exec,
                    x[column],
                    config,
                    tick,
                ),
            ),
            axis=1,
        ),
    )


def _run_strategy(
    strategy_exec: ChunkStrategy,
    input: ChunkInput,
    config: ChunkingConfig,
    tick: ProgressTicker,
) -> list[str | tuple[list[str] | None, str, int]]:
    """Run strategy method definition."""
    if isinstance(input, str):
        return [item.text_chunk for item in strategy_exec([input], config, tick)]

    # We can work with both just a list of text content
    # or a list of tuples of (document_id, text content)
    texts = [item if isinstance(item, str) else item[1] for item in input]
    strategy_results = strategy_exec(texts, config, tick)
    results = []
    for strategy_result in strategy_results:
        doc_indices = strategy_result.source_doc_indices
        if isinstance(input[doc_indices[0]], str):
            results.append(strategy_result.text_chunk)
        else:
            doc_ids = [input[doc_idx][0] for doc_idx in doc_indices]
            results.append((
                doc_ids,
                strategy_result.text_chunk,
                strategy_result.n_tokens,
            ))
    return results


def _load_strategy(strategy: ChunkStrategyType) -> ChunkStrategy:
    """Load strategy method definition."""
    match strategy:
        case ChunkStrategyType.tokens:
            from graphrag.index.operations.chunk_text.strategies import run_tokens

            return run_tokens
        case ChunkStrategyType.sentence:
            # NLTK
            from graphrag.index.operations.chunk_text.bootstrap import bootstrap
            from graphrag.index.operations.chunk_text.strategies import run_sentences

            bootstrap()
            return run_sentences
        case _:
            msg = f"Unknown strategy: {strategy}"
            raise ValueError(msg)


def _get_num_total(output: pd.DataFrame, column: str) -> int:
    num_total = 0
    for row in output[column]:
        if isinstance(row, str):
            num_total += 1
        else:
            num_total += len(row)
    return num_total

View File

@ -1,52 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_tokens and run_sentences methods."""
from collections.abc import Iterable
import nltk
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import (
    TokenChunkerOptions,
    split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer


def run_tokens(
    input: list[str],
    config: ChunkingConfig,
    tick: ProgressTicker,
) -> Iterable[TextChunk]:
    """Chunks text into chunks based on encoding tokens."""
    tokens_per_chunk = config.size
    chunk_overlap = config.overlap
    tokenizer = get_tokenizer(encoding_model=config.encoding_model)
    return split_multiple_texts_on_tokens(
        input,
        TokenChunkerOptions(
            chunk_overlap=chunk_overlap,
            tokens_per_chunk=tokens_per_chunk,
            encode=tokenizer.encode,
            decode=tokenizer.decode,
        ),
        tick,
    )


def run_sentences(
    input: list[str], _config: ChunkingConfig, tick: ProgressTicker
) -> Iterable[TextChunk]:
    """Chunks text into multiple parts by sentence."""
    for doc_idx, text in enumerate(input):
        sentences = nltk.sent_tokenize(text)
        for sentence in sentences:
            yield TextChunk(
                text_chunk=sentence,
                source_doc_indices=[doc_idx],
            )
        tick(1)

View File

@ -1,27 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'TextChunk' model."""
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.logger.progress import ProgressTicker
@dataclass
class TextChunk:
    """Text chunk class definition."""

    text_chunk: str
    source_doc_indices: list[int]
    n_tokens: int | None = None


ChunkInput = str | list[str] | list[tuple[str, str]]
"""Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text)."""

ChunkStrategy = Callable[
    [list[str], ChunkingConfig, ProgressTicker], Iterable[TextChunk]
]

View File

@ -8,9 +8,9 @@ import logging
from dataclasses import dataclass
import numpy as np
from graphrag_chunking.token_chunker import split_text_on_tokens
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
from graphrag.index.utils.is_null import is_null
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
@ -39,17 +39,15 @@ async def run_embed_text(
if is_null(input):
return TextEmbeddingResult(embeddings=None)
splitter = _get_splitter(tokenizer, batch_max_tokens)
semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads)
# Break up the input texts. The sizes here indicate how many snippets are in each input text
texts, input_sizes = _prepare_embed_texts(input, splitter)
texts, input_sizes = _prepare_embed_texts(input, tokenizer, batch_max_tokens)
text_batches = _create_text_batches(
texts,
tokenizer,
batch_size,
batch_max_tokens,
splitter,
)
logger.info(
"embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
@ -72,13 +70,6 @@ async def run_embed_text(
return TextEmbeddingResult(embeddings=embeddings)
def _get_splitter(tokenizer: Tokenizer, batch_max_tokens: int) -> TokenTextSplitter:
return TokenTextSplitter(
tokenizer=tokenizer,
chunk_size=batch_max_tokens,
)
async def _execute(
model: EmbeddingModel,
chunks: list[list[str]],
@ -100,9 +91,9 @@ async def _execute(
def _create_text_batches(
texts: list[str],
tokenizer: Tokenizer,
max_batch_size: int,
max_batch_tokens: int,
splitter: TokenTextSplitter,
) -> list[list[str]]:
"""Create batches of texts to embed."""
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@ -112,7 +103,7 @@ def _create_text_batches(
current_batch_tokens = 0
for text in texts:
token_count = splitter.num_tokens(text)
token_count = tokenizer.num_tokens(text)
if (
len(current_batch) >= max_batch_size
or current_batch_tokens + token_count > max_batch_tokens
@ -131,18 +122,23 @@ def _create_text_batches(
def _prepare_embed_texts(
input: list[str], splitter: TokenTextSplitter
input: list[str],
tokenizer: Tokenizer,
batch_max_tokens: int = 8191,
chunk_overlap: int = 100,
) -> tuple[list[str], list[int]]:
sizes: list[int] = []
snippets: list[str] = []
for text in input:
# Split the input text and filter out any empty content
split_texts = splitter.split_text(text)
if split_texts is None:
continue
split_texts = split_text_on_tokens(
text,
chunk_size=batch_max_tokens,
chunk_overlap=chunk_overlap,
encode=tokenizer.encode,
decode=tokenizer.decode,
)
split_texts = [text for text in split_texts if len(text) > 0]
sizes.append(len(split_texts))
snippets.extend(split_texts)

View File

@ -1,15 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Token limit method definition."""
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
def check_token_limit(text, max_token):
    """Check token limit."""
    text_splitter = TokenTextSplitter(chunk_size=max_token, chunk_overlap=0)
    docs = text_splitter.split_text(text)
    if len(docs) > 1:
        return 0
    return 1

View File

@ -1,18 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'TokenChunkerOptions', 'TextSplitter', 'NoopTextSplitter', 'TokenTextSplitter', 'split_single_text_on_tokens', and 'split_multiple_texts_on_tokens'."""
"""A module containing 'TokenTextSplitter' class and 'split_single_text_on_tokens' function."""
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, cast
from abc import ABC
from collections.abc import Callable
from typing import cast
import pandas as pd
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@ -24,22 +21,8 @@ LengthFn = Callable[[str], int]
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class TokenChunkerOptions:
"""TokenChunkerOptions data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
tokens_per_chunk: int
"""Maximum number of tokens per chunk"""
decode: DecodeFn
""" Function to decode a list of token ids to a string"""
encode: EncodeFn
""" Function to encode a string to a list of token ids"""
class TextSplitter(ABC):
"""Text splitter class definition."""
class TokenTextSplitter(ABC):
"""Token text splitter class definition."""
_chunk_size: int
_chunk_overlap: int
@ -58,6 +41,7 @@ class TextSplitter(ABC):
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
tokenizer: Tokenizer | None = None,
):
"""Init method definition."""
self._chunk_size = chunk_size
@ -66,30 +50,6 @@ class TextSplitter(ABC):
self._keep_separator = keep_separator
self._add_start_index = add_start_index
self._strip_whitespace = strip_whitespace
@abstractmethod
def split_text(self, text: str | list[str]) -> Iterable[str]:
"""Split text method definition."""
class NoopTextSplitter(TextSplitter):
"""Noop text splitter class definition."""
def split_text(self, text: str | list[str]) -> Iterable[str]:
"""Split text method definition."""
return [text] if isinstance(text, str) else text
class TokenTextSplitter(TextSplitter):
"""Token text splitter class definition."""
def __init__(
self,
tokenizer: Tokenizer | None = None,
**kwargs: Any,
):
"""Init method definition."""
super().__init__(**kwargs)
self._tokenizer = tokenizer or get_tokenizer()
def num_tokens(self, text: str) -> int:
@ -106,68 +66,37 @@ class TokenTextSplitter(TextSplitter):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
token_chunker_options = TokenChunkerOptions(
return split_single_text_on_tokens(
text,
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=self._tokenizer.encode,
)
return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)
def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
def split_single_text_on_tokens(
text: str,
tokens_per_chunk: int,
chunk_overlap: int,
encode: EncodeFn,
decode: DecodeFn,
) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)
input_ids = encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
chunk_text = decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []
mapped_ids = []
for source_doc_idx, text in enumerate(texts):
encoded = tokenizer.encode(text)
if tick:
tick(1) # Track progress if tick callback is provided
mapped_ids.append((source_doc_idx, encoded))
input_ids = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
start_idx += tokens_per_chunk - chunk_overlap
cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result

View File

@ -8,15 +8,18 @@ import logging
from typing import Any, cast
import pandas as pd
from graphrag_chunking.add_metadata import add_metadata
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunker_factory import create_chunker
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@ -30,17 +33,14 @@ async def run_workflow(
logger.info("Workflow started: create_base_text_units")
documents = await load_table_from_storage("documents", context.output_storage)
chunks = config.chunks
tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
output = create_base_text_units(
documents,
context.callbacks,
chunks.size,
chunks.overlap,
chunks.encoding_model,
strategy=chunks.strategy,
prepend_metadata=chunks.prepend_metadata,
chunk_size_includes_metadata=chunks.chunk_size_includes_metadata,
tokenizer=tokenizer,
chunker=chunker,
prepend_metadata=config.chunking.prepend_metadata,
)
await write_table_to_storage(output, "text_units", context.output_storage)
@ -52,70 +52,32 @@ async def run_workflow(
def create_base_text_units(
documents: pd.DataFrame,
callbacks: WorkflowCallbacks,
size: int,
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
prepend_metadata: bool,
chunk_size_includes_metadata: bool,
tokenizer: Tokenizer,
chunker: Chunker,
prepend_metadata: bool | None = False,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
documents.sort_values(by=["id"], ascending=[True], inplace=True)
tokenizer = get_tokenizer(encoding_model=encoding_model)
def chunker(row: pd.Series) -> Any:
line_delimiter = ".\n"
metadata_str = ""
metadata_tokens = 0
if prepend_metadata and "metadata" in row:
metadata = row["metadata"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
if isinstance(metadata, dict):
metadata_str = (
line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
+ line_delimiter
)
if chunk_size_includes_metadata:
metadata_tokens = len(tokenizer.encode(metadata_str))
if metadata_tokens >= size:
message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
raise ValueError(message)
chunked = chunk_text(
pd.DataFrame([row]).reset_index(drop=True),
column="text",
size=size - metadata_tokens,
overlap=overlap,
encoding_model=encoding_model,
strategy=strategy,
callbacks=callbacks,
)[0]
if prepend_metadata:
for index, chunk in enumerate(chunked):
if isinstance(chunk, str):
chunked[index] = metadata_str + chunk
else:
chunked[index] = (
(chunk[0], metadata_str + chunk[1], chunk[2]) if chunk else None
)
row["chunks"] = chunked
return row
total_rows = len(documents)
tick = progress_ticker(callbacks.progress, total_rows)
# Track progress of row-wise apply operation
total_rows = len(documents)
logger.info("Starting chunking process for %d documents", total_rows)
def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
"""Add logging to chunker execution."""
result = chunker(row)
row["chunks"] = [chunk.text for chunk in chunker.chunk(row["text"])]
metadata = row.get("metadata", None)
if prepend_metadata and metadata is not None:
metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
row["chunks"] = [
add_metadata(chunk, metadata, line_delimiter=".\n")
for chunk in row["chunks"]
]
tick()
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return result
return row
text_units = documents.apply(
lambda row: chunker_with_logging(row, row.name), axis=1

View File

@ -9,6 +9,7 @@ from typing import Any
import numpy as np
import pandas as pd
from graphrag_cache.noop_cache import NoopCache
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_storage import create_storage
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@ -46,8 +47,6 @@ async def load_docs_in_chunks(
select_method: DocSelectionType,
limit: int,
logger: logging.Logger,
chunk_size: int,
overlap: int,
n_subset_max: int = N_SUBSET_MAX,
k: int = K,
) -> list[str]:
@ -63,19 +62,15 @@ async def load_docs_in_chunks(
cache=NoopCache(),
)
tokenizer = get_tokenizer(embeddings_llm_settings)
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
input_storage = create_storage(config.input.storage)
input_reader = create_input_reader(config.input, input_storage)
dataset = await input_reader.read_files()
chunk_config = config.chunks
chunks_df = create_base_text_units(
documents=pd.DataFrame(dataset),
callbacks=NoopWorkflowCallbacks(),
size=chunk_size,
overlap=overlap,
encoding_model=chunk_config.encoding_model,
strategy=chunk_config.strategy,
prepend_metadata=chunk_config.prepend_metadata,
chunk_size_includes_metadata=chunk_config.chunk_size_includes_metadata,
tokenizer=tokenizer,
chunker=chunker,
)
# Depending on the select method, build the dataset

View File

@ -12,7 +12,7 @@ from graphrag.tokenizer.tokenizer import Tokenizer
def get_tokenizer(
model_config: LanguageModelConfig | None = None,
encoding_model: str = ENCODING_MODEL,
encoding_model: str | None = None,
) -> Tokenizer:
"""
Get the tokenizer for the given model configuration or fallback to a tiktoken based tokenizer.
@ -38,4 +38,6 @@ def get_tokenizer(
return LitellmTokenizer(model_name=model_config.model)
if encoding_model is None:
encoding_model = ENCODING_MODEL
return TiktokenTokenizer(encoding_name=encoding_model)

View File

@ -53,6 +53,7 @@ package = false
members = ["packages/*"]
[tool.uv.sources]
graphrag-chunking = { workspace = true }
graphrag-common = { workspace = true }
graphrag-storage = { workspace = true }
graphrag-cache = { workspace = true }
@ -70,6 +71,7 @@ _semversioner_release = "semversioner release"
_semversioner_changelog = "semversioner changelog > CHANGELOG.md"
# Add more update toml tasks as packages are added
_semversioner_update_graphrag_toml_version = "update-toml update --file packages/graphrag/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_chunking_toml_version = "update-toml update --file packages/graphrag-chunking/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_common_toml_version = "update-toml update --file packages/graphrag-common/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
_semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)"
@ -107,6 +109,7 @@ sequence = [
# Add more update toml tasks as packages are added
'_semversioner_update_graphrag_toml_version',
'_semversioner_update_graphrag_common_toml_version',
'_semversioner_update_graphrag_chunking_toml_version',
'_semversioner_update_graphrag_storage_toml_version',
'_semversioner_update_graphrag_cache_toml_version',
'_semversioner_update_workspace_dependency_versions',

View File

@ -0,0 +1,186 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from unittest.mock import Mock, patch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_chunking.token_chunker import (
    split_text_on_tokens,
)
class MockTokenizer(Tokenizer):
    def encode(self, text) -> list[int]:
        return [ord(char) for char in text]

    def decode(self, tokens) -> str:
        return "".join(chr(id) for id in tokens)
class TestRunSentences:
    def setup_method(self, method):
        bootstrap()

    def test_basic_functionality(self):
        """Test basic sentence splitting without metadata"""
        input = "This is a test. Another sentence. And a third one!"
        chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
        chunks = chunker.chunk(input)
        assert len(chunks) == 3
        assert chunks[0].text == "This is a test."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 0
        assert chunks[0].end_char == 14
        assert chunks[1].text == "Another sentence."
        assert chunks[1].index == 1
        assert chunks[1].start_char == 16
        assert chunks[1].end_char == 32
        assert chunks[2].text == "And a third one!"
        assert chunks[2].index == 2
        assert chunks[2].start_char == 34
        assert chunks[2].end_char == 49

    def test_mixed_whitespace_handling(self):
        """Test input with irregular whitespace"""
        input = "   Sentence with spaces. Another one!  "
        chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
        chunks = chunker.chunk(input)
        assert len(chunks) == 2
        assert chunks[0].text == "Sentence with spaces."
        assert chunks[0].index == 0
        assert chunks[0].start_char == 3
        assert chunks[0].end_char == 23
        assert chunks[1].text == "Another one!"
        assert chunks[1].index == 1
        assert chunks[1].start_char == 25
        assert chunks[1].end_char == 36
class TestRunTokens:
    @patch("tiktoken.get_encoding")
    def test_basic_functionality(self, mock_get_encoding):
        mock_encoder = Mock()
        mock_encoder.encode.side_effect = lambda x: list(x.encode())
        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
        mock_get_encoding.return_value = mock_encoder
        input = "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
        config = ChunkingConfig(
            size=5,
            overlap=1,
            encoding_model="fake-encoding",
            type=ChunkerType.Tokens,
        )
        chunker = create_chunker(config, mock_encoder.encode, mock_encoder.decode)
        chunks = chunker.chunk(input)
        assert len(chunks) > 0
def test_split_text_str_empty():
    tokenizer = get_tokenizer()
    result = split_text_on_tokens(
        "",
        chunk_size=5,
        chunk_overlap=2,
        encode=tokenizer.encode,
        decode=tokenizer.decode,
    )
    assert result == []
def test_split_text_on_tokens():
    text = "This is a test text, meaning to be taken seriously by this test only."
    mocked_tokenizer = MockTokenizer()
    expected_splits = [
        "This is a ",
        "is a test ",
        "test text,",
        "text, mean",
        " meaning t",
        "ing to be ",
        "o be taken",
        "taken seri",  # cspell:disable-line
        " seriously",
        "ously by t",  # cspell:disable-line
        " by this t",
        "his test o",
        "est only.",
    ]
    result = split_text_on_tokens(
        text=text,
        chunk_overlap=5,
        chunk_size=10,
        decode=mocked_tokenizer.decode,
        encode=lambda text: mocked_tokenizer.encode(text),
    )
    assert result == expected_splits
def test_split_text_on_tokens_one_overlap():
    text = "This is a test text, meaning to be taken seriously by this test only."
    tokenizer = get_tokenizer(encoding_model="o200k_base")
    expected_splits = [
        "This is",
        " is a",
        " a test",
        " test text",
        " text,",
        ", meaning",
        " meaning to",
        " to be",
        " be taken",
        " taken seriously",
        " seriously by",
        " by this",
        " this test",
        " test only",
        " only.",
    ]
    result = split_text_on_tokens(
        text=text,
        chunk_size=2,
        chunk_overlap=1,
        decode=tokenizer.decode,
        encode=tokenizer.encode,
    )
    assert result == expected_splits
def test_split_text_on_tokens_no_overlap():
text = "This is a test text, meaning to be taken seriously by this test only."
tokenizer = get_tokenizer(encoding_model="o200k_base")
expected_splits = [
"This is a",
" test text,",
" meaning to be",
" taken seriously by",
" this test only",
".",
]
result = split_text_on_tokens(
text=text,
chunk_size=3,
chunk_overlap=0,
decode=tokenizer.decode,
encode=tokenizer.encode,
)
assert result == expected_splits
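
Note: all three expected-split tables above are consistent with a simple sliding window over the token stream, stepping by chunk_size - chunk_overlap and stopping once the window reaches the end. A rough sketch matching the call signature used in these tests; the shipped implementation may differ in details:

from collections.abc import Callable

def split_text_on_tokens(
    text: str,
    *,
    chunk_size: int,
    chunk_overlap: int,
    encode: Callable[[str], list[int]],
    decode: Callable[[list[int]], str],
) -> list[str]:
    """Slide a fixed-size token window over the text, stepping by size - overlap."""
    tokens = encode(text)
    if not tokens:
        return []  # empty input yields no chunks, per test_split_text_str_empty
    step = chunk_size - chunk_overlap  # assumes overlap < size
    chunks: list[str] = []
    for start in range(0, len(tokens), step):
        chunks.append(decode(tokens[start : start + chunk_size]))
        if start + chunk_size >= len(tokens):
            break  # final (possibly short) window consumed the tail
    return chunks
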

View File

@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag_chunking.add_metadata import add_metadata
def test_add_metadata_one_row():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello"}
results = [add_metadata(chunk, metadata) for chunk in chunks]
assert results[0] == "message: hello\nThis is a test."
assert results[1] == "message: hello\nAnother sentence."
def test_add_metadata_one_row_append():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello"}
results = [add_metadata(chunk, metadata, append=True) for chunk in chunks]
assert results[0] == "This is a test.message: hello\n"
assert results[1] == "Another sentence.message: hello\n"
def test_add_metadata_multiple_rows():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello", "tag": "first"}
results = [add_metadata(chunk, metadata) for chunk in chunks]
assert results[0] == "message: hello\ntag: first\nThis is a test."
assert results[1] == "message: hello\ntag: first\nAnother sentence."
def test_add_metadata_custom_delimiters():
"""Test prepending metadata to chunks"""
chunks = ["This is a test.", "Another sentence."]
metadata = {"message": "hello", "tag": "first"}
results = [
add_metadata(chunk, metadata, delimiter="-", line_delimiter="_")
for chunk in chunks
]
assert results[0] == "message-hello_tag-first_This is a test."
assert results[1] == "message-hello_tag-first_Another sentence."

View File

@ -5,7 +5,6 @@ from dataclasses import asdict
import graphrag.config.defaults as defs
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.config.models.cluster_graph_config import ClusterGraphConfig
from graphrag.config.models.community_reports_config import CommunityReportsConfig
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@ -29,6 +28,7 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.index.input.input_config import InputConfig
from graphrag_cache import CacheConfig
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_storage import StorageConfig
from pydantic import BaseModel
@ -164,10 +164,9 @@ def assert_text_embedding_configs(
def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
assert actual.size == expected.size
assert actual.overlap == expected.overlap
assert actual.strategy == expected.strategy
assert actual.type == expected.type
assert actual.encoding_model == expected.encoding_model
assert actual.prepend_metadata == expected.prepend_metadata
assert actual.chunk_size_includes_metadata == expected.chunk_size_includes_metadata
def assert_snapshots_configs(
@ -345,7 +344,7 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
assert_cache_configs(actual.cache, expected.cache)
assert_input_configs(actual.input, expected.input)
assert_text_embedding_configs(actual.embed_text, expected.embed_text)
assert_chunking_configs(actual.chunks, expected.chunks)
assert_chunking_configs(actual.chunking, expected.chunking)
assert_snapshots_configs(actual.snapshots, expected.snapshots)
assert_extract_graph_configs(actual.extract_graph, expected.extract_graph)
assert_extract_graph_nlp_configs(
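
Note: the updated assertions in this file reflect the config surface change: the block was renamed from chunks to chunking, strategy became type, and chunk_size_includes_metadata was dropped. A sketch of the implied graphrag_chunking model, assuming pydantic as used elsewhere above; enum values and defaults here are illustrative guesses, not the published schema:

from enum import Enum
from pydantic import BaseModel

class ChunkerType(str, Enum):
    # Member names taken from the tests; the string values are guesses.
    Tokens = "tokens"
    Sentence = "sentence"

class ChunkingConfig(BaseModel):
    # Field set implied by assert_chunking_configs; defaults are illustrative.
    size: int = 1200
    overlap: int = 100
    type: ChunkerType = ChunkerType.Tokens
    encoding_model: str = "cl100k_base"
    prepend_metadata: bool = False
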

View File

@ -1,180 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from unittest import mock
from unittest.mock import ANY, Mock
import pandas as pd
import pytest
from graphrag.config.enums import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import (
_get_num_total,
_load_strategy,
_run_strategy,
chunk_text,
)
from graphrag.index.operations.chunk_text.typing import (
TextChunk,
)
def test_get_num_total_default():
output = pd.DataFrame({"column": ["a", "b", "c"]})
total = _get_num_total(output, "column")
assert total == 3
def test_get_num_total_array():
output = pd.DataFrame({"column": [["a", "b", "c"], ["x", "y"]]})
total = _get_num_total(output, "column")
assert total == 5
def test_load_strategy_tokens():
strategy_type = ChunkStrategyType.tokens
strategy_loaded = _load_strategy(strategy_type)
assert strategy_loaded.__name__ == "run_tokens"
def test_load_strategy_sentence():
strategy_type = ChunkStrategyType.sentence
strategy_loaded = _load_strategy(strategy_type)
assert strategy_loaded.__name__ == "run_sentences"
def test_load_strategy_none():
strategy_type = ChunkStrategyType
with pytest.raises(
ValueError, match="Unknown strategy: <enum 'ChunkStrategyType'>"
):
_load_strategy(strategy_type) # type: ignore
def test_run_strategy_str():
input = "text test for run strategy"
config = Mock()
tick = Mock()
strategy_mocked = Mock()
strategy_mocked.return_value = [
TextChunk(
text_chunk="text test for run strategy",
source_doc_indices=[0],
)
]
runned = _run_strategy(strategy_mocked, input, config, tick)
assert runned == ["text test for run strategy"]
def test_run_strategy_arr_str():
input = ["text test for run strategy", "use for strategy"]
config = Mock()
tick = Mock()
strategy_mocked = Mock()
strategy_mocked.return_value = [
TextChunk(
text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
),
TextChunk(text_chunk="use for strategy", source_doc_indices=[1], n_tokens=3),
]
expected = [
"text test for run strategy",
"use for strategy",
]
runned = _run_strategy(strategy_mocked, input, config, tick)
assert runned == expected
def test_run_strategy_arr_tuple():
input = [("text test for run strategy", "3"), ("use for strategy", "5")]
config = Mock()
tick = Mock()
strategy_mocked = Mock()
strategy_mocked.return_value = [
TextChunk(
text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
),
TextChunk(text_chunk="use for strategy", source_doc_indices=[1], n_tokens=3),
]
expected = [
(
["text test for run strategy"],
"text test for run strategy",
5,
),
(
["use for strategy"],
"use for strategy",
3,
),
]
runned = _run_strategy(strategy_mocked, input, config, tick)
assert runned == expected
def test_run_strategy_arr_tuple_same_doc():
input = [("text test for run strategy", "3"), ("use for strategy", "5")]
config = Mock()
tick = Mock()
strategy_mocked = Mock()
strategy_mocked.return_value = [
TextChunk(
text_chunk="text test for run strategy", source_doc_indices=[0], n_tokens=5
),
TextChunk(text_chunk="use for strategy", source_doc_indices=[0], n_tokens=3),
]
expected = [
(
["text test for run strategy"],
"text test for run strategy",
5,
),
(
["text test for run strategy"],
"use for strategy",
3,
),
]
runned = _run_strategy(strategy_mocked, input, config, tick)
assert runned == expected
@mock.patch("graphrag.index.operations.chunk_text.chunk_text._load_strategy")
@mock.patch("graphrag.index.operations.chunk_text.chunk_text._run_strategy")
@mock.patch("graphrag.index.operations.chunk_text.chunk_text.progress_ticker")
def test_chunk_text(mock_progress_ticker, mock_run_strategy, mock_load_strategy):
input_data = pd.DataFrame({"name": ["The Shining"]})
column = "name"
size = 10
overlap = 2
encoding_model = "model"
strategy = ChunkStrategyType.sentence
callbacks = Mock()
callbacks.progress = Mock()
mock_load_strategy.return_value = Mock()
mock_progress_ticker.return_value = Mock()
chunk_text(input_data, column, size, overlap, encoding_model, strategy, callbacks)
mock_run_strategy.assert_called_with(
mock_load_strategy(), "The Shining", ANY, mock_progress_ticker.return_value
)

View File

@ -1,127 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from unittest.mock import Mock, patch
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.bootstrap import bootstrap
from graphrag.index.operations.chunk_text.strategies import (
run_sentences,
run_tokens,
)
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.tokenizer.get_tokenizer import get_tokenizer
class TestRunSentences:
def setup_method(self, method):
bootstrap()
def test_basic_functionality(self):
"""Test basic sentence splitting without metadata"""
input = ["This is a test. Another sentence."]
tick = Mock()
chunks = list(run_sentences(input, ChunkingConfig(), tick))
assert len(chunks) == 2
assert chunks[0].text_chunk == "This is a test."
assert chunks[1].text_chunk == "Another sentence."
assert all(c.source_doc_indices == [0] for c in chunks)
tick.assert_called_once_with(1)
def test_multiple_documents(self):
"""Test processing multiple input documents"""
input = ["First. Document.", "Second. Doc."]
tick = Mock()
chunks = list(run_sentences(input, ChunkingConfig(), tick))
assert len(chunks) == 4
assert chunks[0].source_doc_indices == [0]
assert chunks[2].source_doc_indices == [1]
assert tick.call_count == 2
def test_mixed_whitespace_handling(self):
"""Test input with irregular whitespace"""
input = [" Sentence with spaces. Another one! "]
chunks = list(run_sentences(input, ChunkingConfig(), Mock()))
assert chunks[0].text_chunk == " Sentence with spaces."
assert chunks[1].text_chunk == "Another one!"
class TestRunTokens:
@patch("tiktoken.get_encoding")
def test_basic_functionality(self, mock_get_encoding):
mock_encoder = Mock()
mock_encoder.encode.side_effect = lambda x: list(x.encode())
mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
mock_get_encoding.return_value = mock_encoder
# Input and config
input = [
"Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
]
config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
tick = Mock()
# Run the function
chunks = list(run_tokens(input, config, tick))
# Verify output
assert len(chunks) > 0
assert all(isinstance(chunk, TextChunk) for chunk in chunks)
tick.assert_called_once_with(1)
@patch("tiktoken.get_encoding")
def test_non_string_input(self, mock_get_encoding):
"""Test handling of non-string input (e.g., numbers)."""
mock_encoder = Mock()
mock_encoder.encode.side_effect = lambda x: list(str(x).encode())
mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
mock_get_encoding.return_value = mock_encoder
input = [123] # Non-string input
config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
tick = Mock()
chunks = list(run_tokens(input, config, tick)) # type: ignore
# Verify non-string input is handled
assert len(chunks) > 0
assert "123" in chunks[0].text_chunk
@patch("tiktoken.get_encoding")
def test_get_encoding_fn_encode(mock_get_encoding):
# Create a mock encoding object with encode and decode methods
mock_encoding = Mock()
mock_encoding.encode = Mock(return_value=[1, 2, 3])
mock_encoding.decode = Mock(return_value="decoded text")
# Configure the mock_get_encoding to return the mock encoding object
mock_get_encoding.return_value = mock_encoding
# Call the function to get encode and decode functions
tokenizer = get_tokenizer(encoding_model="mock_encoding")
# Test the encode function
encoded_text = tokenizer.encode("test text")
assert encoded_text == [1, 2, 3]
mock_encoding.encode.assert_called_once_with("test text")
@patch("tiktoken.get_encoding")
def test_get_encoding_fn_decode(mock_get_encoding):
# Create a mock encoding object with encode and decode methods
mock_encoding = Mock()
mock_encoding.encode = Mock(return_value=[1, 2, 3])
mock_encoding.decode = Mock(return_value="decoded text")
# Configure the mock_get_encoding to return the mock encoding object
mock_get_encoding.return_value = mock_encoding
# Call the function to get encode and decode functions
tokenizer = get_tokenizer(encoding_model="mock_encoding")
decoded_text = tokenizer.decode([1, 2, 3])
assert decoded_text == "decoded text"
mock_encoding.decode.assert_called_once_with([1, 2, 3])

View File

@ -1,2 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

View File

@ -1,169 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from unittest import mock
from unittest.mock import MagicMock
import pytest
import tiktoken
from graphrag.index.text_splitting.text_splitting import (
NoopTextSplitter,
TokenChunkerOptions,
TokenTextSplitter,
split_multiple_texts_on_tokens,
split_single_text_on_tokens,
)
def test_noop_text_splitter() -> None:
splitter = NoopTextSplitter()
assert list(splitter.split_text("some text")) == ["some text"]
assert list(splitter.split_text(["some", "text"])) == ["some", "text"]
class MockTokenizer:
def encode(self, text):
return [ord(char) for char in text]
def decode(self, token_ids):
return "".join(chr(id) for id in token_ids)
def test_split_text_str_empty():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text("")
assert result == []
def test_split_text_str_bool():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text(None) # type: ignore
assert result == []
def test_split_text_str_int():
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
with pytest.raises(TypeError):
splitter.split_text(123) # type: ignore
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
def test_split_text_large_input(mock_split):
large_text = "a" * 10_000
mock_split.return_value = ["chunk"] * 2_000
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
result = splitter.split_text(large_text)
assert len(result) == 2_000, "Large input was not split correctly"
mock_split.assert_called_once()
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
@mock.patch("graphrag.index.text_splitting.text_splitting.TokenChunkerOptions")
def test_token_text_splitter(mock_tokenizer, mock_split_text):
text = "chunk1 chunk2 chunk3"
expected_chunks = ["chunk1", "chunk2", "chunk3"]
mocked_tokenizer = MagicMock()
mock_tokenizer.return_value = mocked_tokenizer
mock_split_text.return_value = expected_chunks
splitter = TokenTextSplitter()
splitter.split_text(["chunk1", "chunk2", "chunk3"])
mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer)
def test_split_single_text_on_tokens():
text = "This is a test text, meaning to be taken seriously by this test only."
mocked_tokenizer = MockTokenizer()
tokenizer = TokenChunkerOptions(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
encode=lambda text: mocked_tokenizer.encode(text),
)
expected_splits = [
"This is a ",
"is a test ",
"test text,",
"text, mean",
" meaning t",
"ing to be ",
"o be taken",
"taken seri", # cspell:disable-line
" seriously",
"ously by t", # cspell:disable-line
" by this t",
"his test o",
"est only.",
]
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
assert result == expected_splits
def test_split_multiple_texts_on_tokens():
texts = [
"This is a test text, meaning to be taken seriously by this test only.",
"This is th second text, meaning to be taken seriously by this test only.",
]
mocked_tokenizer = MockTokenizer()
mock_tick = MagicMock()
tokenizer = TokenChunkerOptions(
chunk_overlap=5,
tokens_per_chunk=10,
decode=mocked_tokenizer.decode,
encode=lambda text: mocked_tokenizer.encode(text),
)
split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick)
mock_tick.assert_called()
def test_split_single_text_on_tokens_no_overlap():
text = "This is a test text, meaning to be taken seriously by this test only."
enc = tiktoken.get_encoding("cl100k_base")
def encode(text: str) -> list[int]:
if not isinstance(text, str):
text = f"{text}"
return enc.encode(text)
def decode(tokens: list[int]) -> str:
return enc.decode(tokens)
tokenizer = TokenChunkerOptions(
chunk_overlap=1,
tokens_per_chunk=2,
decode=decode,
encode=lambda text: encode(text),
)
expected_splits = [
"This is",
" is a",
" a test",
" test text",
" text,",
", meaning",
" meaning to",
" to be",
" be taken", # cspell:disable-line
" taken seriously", # cspell:disable-line
" seriously by",
" by this", # cspell:disable-line
" this test",
" test only",
" only.",
]
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
assert result == expected_splits

View File

@ -35,7 +35,7 @@ async def test_create_base_text_units_metadata():
config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore
config.input.metadata = ["title"]
config.chunks.prepend_metadata = True
config.chunking.prepend_metadata = True
await update_document_metadata(config.input.metadata, context)
@ -43,22 +43,3 @@ async def test_create_base_text_units_metadata():
actual = await load_table_from_storage("text_units", context.output_storage)
compare_outputs(actual, expected, ["text", "document_id", "n_tokens"])
async def test_create_base_text_units_metadata_included_in_chunk():
expected = load_test_table("text_units_metadata_included_chunk")
context = await create_test_context()
config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore
config.input.metadata = ["title"]
config.chunks.prepend_metadata = True
config.chunks.chunk_size_includes_metadata = True
await update_document_metadata(config.input.metadata, context)
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.output_storage)
# only check the columns from the base workflow - our expected table is the final and will have more
compare_outputs(actual, expected, columns=["text", "document_id", "n_tokens"])

2482
uv.lock generated

File diff suppressed because it is too large