Streamline chunking config

This commit is contained in:
Nathan Evans 2025-12-23 10:28:14 -08:00
parent a741bfb8d7
commit 7748493fdf
13 changed files with 47 additions and 51 deletions

View File

@ -6,8 +6,8 @@
from enum import StrEnum
class ChunkStrategyType(StrEnum):
"""ChunkStrategy class definition."""
class ChunkerType(StrEnum):
"""ChunkerType class definition."""
Tokens = "tokens"
Sentence = "sentence"

View File

@ -7,7 +7,7 @@ from collections.abc import Callable
from graphrag_common.factory.factory import Factory, ServiceScope
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker import Chunker
from graphrag_chunking.chunking_config import ChunkingConfig
@ -58,18 +58,18 @@ def create_chunker(
config_model["encode"] = encode
if decode is not None:
config_model["decode"] = decode
chunker_strategy = config.strategy
chunker_strategy = config.type
if chunker_strategy not in chunker_factory:
match chunker_strategy:
case ChunkStrategyType.Tokens:
case ChunkerType.Tokens:
from graphrag_chunking.token_chunker import TokenChunker
register_chunker(ChunkStrategyType.Tokens, TokenChunker)
case ChunkStrategyType.Sentence:
register_chunker(ChunkerType.Tokens, TokenChunker)
case ChunkerType.Sentence:
from graphrag_chunking.sentence_chunker import SentenceChunker
register_chunker(ChunkStrategyType.Sentence, SentenceChunker)
register_chunker(ChunkerType.Sentence, SentenceChunker)
case _:
msg = f"ChunkingConfig.strategy '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
raise ValueError(msg)

View File

@ -5,7 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
from graphrag_chunking.chunk_strategy_type import ChunkerType
class ChunkingConfig(BaseModel):
@ -14,9 +14,9 @@ class ChunkingConfig(BaseModel):
model_config = ConfigDict(extra="allow")
"""Allow extra fields to support custom cache implementations."""
strategy: str = Field(
description="The chunking strategy to use.",
default=ChunkStrategyType.Tokens,
type: str = Field(
description="The chunking type to use.",
default=ChunkerType.Tokens,
)
encoding_model: str | None = Field(
description="The encoding model to use.",

View File

@ -306,14 +306,14 @@ def _prompt_tune_cli(
help="The minimum number of examples to generate/include in the entity extraction prompt.",
),
chunk_size: int = typer.Option(
graphrag_config_defaults.chunks.size,
graphrag_config_defaults.chunking.size,
"--chunk-size",
help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
help="The size of each example text chunk. Overrides chunking.size in the configuration file.",
),
overlap: int = typer.Option(
graphrag_config_defaults.chunks.overlap,
graphrag_config_defaults.chunking.overlap,
"--overlap",
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
help="The overlap size for chunking documents. Overrides chunking.overlap in the configuration file.",
),
language: str | None = typer.Option(
None,

View File

@ -61,11 +61,11 @@ async def prompt_tune(
)
# override chunking config in the configuration
if chunk_size != graph_config.chunks.size:
graph_config.chunks.size = chunk_size
if chunk_size != graph_config.chunking.size:
graph_config.chunking.size = chunk_size
if overlap != graph_config.chunks.overlap:
graph_config.chunks.overlap = overlap
if overlap != graph_config.chunking.overlap:
graph_config.chunking.overlap = overlap
# configure the root logger with the specified log level
from graphrag.logger.standard_logging import init_loggers

View File

@ -8,7 +8,7 @@ from pathlib import Path
from typing import ClassVar
from graphrag_cache import CacheType
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_storage import StorageType
from graphrag.config.embeddings import default_embeddings
@ -57,10 +57,10 @@ class BasicSearchDefaults:
@dataclass
class ChunksDefaults:
"""Default values for chunks."""
class ChunkingDefaults:
"""Default values for chunking."""
strategy: str = ChunkStrategyType.Tokens
type: str = ChunkerType.Tokens
size: int = 1200
overlap: int = 100
encoding_model: str = ENCODING_MODEL
@ -126,7 +126,6 @@ class EmbedTextDefaults:
batch_size: int = 16
batch_max_tokens: int = 8191
names: list[str] = field(default_factory=lambda: default_embeddings)
strategy: None = None
@dataclass
@ -139,7 +138,6 @@ class ExtractClaimsDefaults:
"Any claims or facts that could be relevant to information discovery."
)
max_gleanings: int = 1
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "extract_claims"
@ -153,7 +151,6 @@ class ExtractGraphDefaults:
default_factory=lambda: ["organization", "person", "geo", "event"]
)
max_gleanings: int = 1
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "extract_graph"
@ -360,7 +357,6 @@ class SummarizeDescriptionsDefaults:
prompt: None = None
max_length: int = 500
max_input_tokens: int = 4_000
strategy: None = None
model_id: str = DEFAULT_CHAT_MODEL_ID
model_instance_name: str = "summarize_descriptions"
@ -401,7 +397,7 @@ class GraphRagConfigDefaults:
cache: CacheDefaults = field(default_factory=CacheDefaults)
input: InputDefaults = field(default_factory=InputDefaults)
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
extract_graph_nlp: ExtractGraphNLPDefaults = field(

View File

@ -54,11 +54,11 @@ input:
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
chunks:
strategy: {graphrag_config_defaults.chunks.strategy}
size: {graphrag_config_defaults.chunks.size}
overlap: {graphrag_config_defaults.chunks.overlap}
encoding_model: {graphrag_config_defaults.chunks.encoding_model}
chunking:
type: {graphrag_config_defaults.chunking.type}
size: {graphrag_config_defaults.chunking.size}
overlap: {graphrag_config_defaults.chunking.overlap}
encoding_model: {graphrag_config_defaults.chunking.encoding_model}
### Output/storage settings ###
## If blob storage is specified in the following four sections,

View File

@ -125,14 +125,14 @@ class GraphRagConfig(BaseModel):
Path(self.input.storage.base_dir).resolve()
)
chunks: ChunkingConfig = Field(
chunking: ChunkingConfig = Field(
description="The chunking configuration to use.",
default=ChunkingConfig(
strategy=graphrag_config_defaults.chunks.strategy,
size=graphrag_config_defaults.chunks.size,
overlap=graphrag_config_defaults.chunks.overlap,
encoding_model=graphrag_config_defaults.chunks.encoding_model,
prepend_metadata=graphrag_config_defaults.chunks.prepend_metadata,
type=graphrag_config_defaults.chunking.type,
size=graphrag_config_defaults.chunking.size,
overlap=graphrag_config_defaults.chunking.overlap,
encoding_model=graphrag_config_defaults.chunking.encoding_model,
prepend_metadata=graphrag_config_defaults.chunking.prepend_metadata,
),
)
"""The chunking configuration to use."""

View File

@ -33,14 +33,14 @@ async def run_workflow(
logger.info("Workflow started: create_base_text_units")
documents = await load_table_from_storage("documents", context.output_storage)
tokenizer = get_tokenizer(encoding_model=config.chunks.encoding_model)
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
output = create_base_text_units(
documents,
context.callbacks,
tokenizer=tokenizer,
chunker=chunker,
prepend_metadata=config.chunks.prepend_metadata,
prepend_metadata=config.chunking.prepend_metadata,
)
await write_table_to_storage(output, "text_units", context.output_storage)

View File

@ -62,7 +62,7 @@ async def load_docs_in_chunks(
cache=NoopCache(),
)
tokenizer = get_tokenizer(embeddings_llm_settings)
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
input_storage = create_storage(config.input.storage)
input_reader = InputReaderFactory().create(
config.input.file_type,

View File

@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag_chunking.bootstrap_nltk import bootstrap
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
from graphrag_chunking.chunk_strategy_type import ChunkerType
from graphrag_chunking.chunker_factory import create_chunker
from graphrag_chunking.chunking_config import ChunkingConfig
from graphrag_chunking.token_chunker import (
@ -29,7 +29,7 @@ class TestRunSentences:
def test_basic_functionality(self):
"""Test basic sentence splitting without metadata"""
input = "This is a test. Another sentence. And a third one!"
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
chunks = chunker.chunk(input)
assert len(chunks) == 3
@ -51,7 +51,7 @@ class TestRunSentences:
def test_mixed_whitespace_handling(self):
"""Test input with irregular whitespace"""
input = " Sentence with spaces. Another one! "
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
chunks = chunker.chunk(input)
assert len(chunks) == 2
@ -80,7 +80,7 @@ class TestRunTokens:
size=5,
overlap=1,
encoding_model="fake-encoding",
strategy=ChunkStrategyType.Tokens,
type=ChunkerType.Tokens,
)
chunker = create_chunker(config, mock_encoder.encode, mock_encoder.decode)

View File

@ -164,7 +164,7 @@ def assert_text_embedding_configs(
def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
assert actual.size == expected.size
assert actual.overlap == expected.overlap
assert actual.strategy == expected.strategy
assert actual.type == expected.type
assert actual.encoding_model == expected.encoding_model
assert actual.prepend_metadata == expected.prepend_metadata
@ -344,7 +344,7 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
assert_cache_configs(actual.cache, expected.cache)
assert_input_configs(actual.input, expected.input)
assert_text_embedding_configs(actual.embed_text, expected.embed_text)
assert_chunking_configs(actual.chunks, expected.chunks)
assert_chunking_configs(actual.chunking, expected.chunking)
assert_snapshots_configs(actual.snapshots, expected.snapshots)
assert_extract_graph_configs(actual.extract_graph, expected.extract_graph)
assert_extract_graph_nlp_configs(

View File

@ -35,7 +35,7 @@ async def test_create_base_text_units_metadata():
config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore
config.input.metadata = ["title"]
config.chunks.prepend_metadata = True
config.chunking.prepend_metadata = True
await update_document_metadata(config.input.metadata, context)