mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Streamline chunking config
This commit is contained in:
parent
a741bfb8d7
commit
7748493fdf
@ -6,8 +6,8 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class ChunkStrategyType(StrEnum):
|
||||
"""ChunkStrategy class definition."""
|
||||
class ChunkerType(StrEnum):
|
||||
"""ChunkerType class definition."""
|
||||
|
||||
Tokens = "tokens"
|
||||
Sentence = "sentence"
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Callable
|
||||
|
||||
from graphrag_common.factory.factory import Factory, ServiceScope
|
||||
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkerType
|
||||
from graphrag_chunking.chunker import Chunker
|
||||
from graphrag_chunking.chunking_config import ChunkingConfig
|
||||
|
||||
@ -58,18 +58,18 @@ def create_chunker(
|
||||
config_model["encode"] = encode
|
||||
if decode is not None:
|
||||
config_model["decode"] = decode
|
||||
chunker_strategy = config.strategy
|
||||
chunker_strategy = config.type
|
||||
|
||||
if chunker_strategy not in chunker_factory:
|
||||
match chunker_strategy:
|
||||
case ChunkStrategyType.Tokens:
|
||||
case ChunkerType.Tokens:
|
||||
from graphrag_chunking.token_chunker import TokenChunker
|
||||
|
||||
register_chunker(ChunkStrategyType.Tokens, TokenChunker)
|
||||
case ChunkStrategyType.Sentence:
|
||||
register_chunker(ChunkerType.Tokens, TokenChunker)
|
||||
case ChunkerType.Sentence:
|
||||
from graphrag_chunking.sentence_chunker import SentenceChunker
|
||||
|
||||
register_chunker(ChunkStrategyType.Sentence, SentenceChunker)
|
||||
register_chunker(ChunkerType.Sentence, SentenceChunker)
|
||||
case _:
|
||||
msg = f"ChunkingConfig.strategy '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
|
||||
raise ValueError(msg)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkerType
|
||||
|
||||
|
||||
class ChunkingConfig(BaseModel):
|
||||
@ -14,9 +14,9 @@ class ChunkingConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
"""Allow extra fields to support custom cache implementations."""
|
||||
|
||||
strategy: str = Field(
|
||||
description="The chunking strategy to use.",
|
||||
default=ChunkStrategyType.Tokens,
|
||||
type: str = Field(
|
||||
description="The chunking type to use.",
|
||||
default=ChunkerType.Tokens,
|
||||
)
|
||||
encoding_model: str | None = Field(
|
||||
description="The encoding model to use.",
|
||||
|
||||
@ -306,14 +306,14 @@ def _prompt_tune_cli(
|
||||
help="The minimum number of examples to generate/include in the entity extraction prompt.",
|
||||
),
|
||||
chunk_size: int = typer.Option(
|
||||
graphrag_config_defaults.chunks.size,
|
||||
graphrag_config_defaults.chunking.size,
|
||||
"--chunk-size",
|
||||
help="The size of each example text chunk. Overrides chunks.size in the configuration file.",
|
||||
help="The size of each example text chunk. Overrides chunking.size in the configuration file.",
|
||||
),
|
||||
overlap: int = typer.Option(
|
||||
graphrag_config_defaults.chunks.overlap,
|
||||
graphrag_config_defaults.chunking.overlap,
|
||||
"--overlap",
|
||||
help="The overlap size for chunking documents. Overrides chunks.overlap in the configuration file.",
|
||||
help="The overlap size for chunking documents. Overrides chunking.overlap in the configuration file.",
|
||||
),
|
||||
language: str | None = typer.Option(
|
||||
None,
|
||||
|
||||
@ -61,11 +61,11 @@ async def prompt_tune(
|
||||
)
|
||||
|
||||
# override chunking config in the configuration
|
||||
if chunk_size != graph_config.chunks.size:
|
||||
graph_config.chunks.size = chunk_size
|
||||
if chunk_size != graph_config.chunking.size:
|
||||
graph_config.chunking.size = chunk_size
|
||||
|
||||
if overlap != graph_config.chunks.overlap:
|
||||
graph_config.chunks.overlap = overlap
|
||||
if overlap != graph_config.chunking.overlap:
|
||||
graph_config.chunking.overlap = overlap
|
||||
|
||||
# configure the root logger with the specified log level
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag_cache import CacheType
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkerType
|
||||
from graphrag_storage import StorageType
|
||||
|
||||
from graphrag.config.embeddings import default_embeddings
|
||||
@ -57,10 +57,10 @@ class BasicSearchDefaults:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunksDefaults:
|
||||
"""Default values for chunks."""
|
||||
class ChunkingDefaults:
|
||||
"""Default values for chunking."""
|
||||
|
||||
strategy: str = ChunkStrategyType.Tokens
|
||||
type: str = ChunkerType.Tokens
|
||||
size: int = 1200
|
||||
overlap: int = 100
|
||||
encoding_model: str = ENCODING_MODEL
|
||||
@ -126,7 +126,6 @@ class EmbedTextDefaults:
|
||||
batch_size: int = 16
|
||||
batch_max_tokens: int = 8191
|
||||
names: list[str] = field(default_factory=lambda: default_embeddings)
|
||||
strategy: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -139,7 +138,6 @@ class ExtractClaimsDefaults:
|
||||
"Any claims or facts that could be relevant to information discovery."
|
||||
)
|
||||
max_gleanings: int = 1
|
||||
strategy: None = None
|
||||
model_id: str = DEFAULT_CHAT_MODEL_ID
|
||||
model_instance_name: str = "extract_claims"
|
||||
|
||||
@ -153,7 +151,6 @@ class ExtractGraphDefaults:
|
||||
default_factory=lambda: ["organization", "person", "geo", "event"]
|
||||
)
|
||||
max_gleanings: int = 1
|
||||
strategy: None = None
|
||||
model_id: str = DEFAULT_CHAT_MODEL_ID
|
||||
model_instance_name: str = "extract_graph"
|
||||
|
||||
@ -360,7 +357,6 @@ class SummarizeDescriptionsDefaults:
|
||||
prompt: None = None
|
||||
max_length: int = 500
|
||||
max_input_tokens: int = 4_000
|
||||
strategy: None = None
|
||||
model_id: str = DEFAULT_CHAT_MODEL_ID
|
||||
model_instance_name: str = "summarize_descriptions"
|
||||
|
||||
@ -401,7 +397,7 @@ class GraphRagConfigDefaults:
|
||||
cache: CacheDefaults = field(default_factory=CacheDefaults)
|
||||
input: InputDefaults = field(default_factory=InputDefaults)
|
||||
embed_text: EmbedTextDefaults = field(default_factory=EmbedTextDefaults)
|
||||
chunks: ChunksDefaults = field(default_factory=ChunksDefaults)
|
||||
chunking: ChunkingDefaults = field(default_factory=ChunkingDefaults)
|
||||
snapshots: SnapshotsDefaults = field(default_factory=SnapshotsDefaults)
|
||||
extract_graph: ExtractGraphDefaults = field(default_factory=ExtractGraphDefaults)
|
||||
extract_graph_nlp: ExtractGraphNLPDefaults = field(
|
||||
|
||||
@ -54,11 +54,11 @@ input:
|
||||
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
||||
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
||||
|
||||
chunks:
|
||||
strategy: {graphrag_config_defaults.chunks.strategy}
|
||||
size: {graphrag_config_defaults.chunks.size}
|
||||
overlap: {graphrag_config_defaults.chunks.overlap}
|
||||
encoding_model: {graphrag_config_defaults.chunks.encoding_model}
|
||||
chunking:
|
||||
type: {graphrag_config_defaults.chunking.type}
|
||||
size: {graphrag_config_defaults.chunking.size}
|
||||
overlap: {graphrag_config_defaults.chunking.overlap}
|
||||
encoding_model: {graphrag_config_defaults.chunking.encoding_model}
|
||||
|
||||
### Output/storage settings ###
|
||||
## If blob storage is specified in the following four sections,
|
||||
|
||||
@ -125,14 +125,14 @@ class GraphRagConfig(BaseModel):
|
||||
Path(self.input.storage.base_dir).resolve()
|
||||
)
|
||||
|
||||
chunks: ChunkingConfig = Field(
|
||||
chunking: ChunkingConfig = Field(
|
||||
description="The chunking configuration to use.",
|
||||
default=ChunkingConfig(
|
||||
strategy=graphrag_config_defaults.chunks.strategy,
|
||||
size=graphrag_config_defaults.chunks.size,
|
||||
overlap=graphrag_config_defaults.chunks.overlap,
|
||||
encoding_model=graphrag_config_defaults.chunks.encoding_model,
|
||||
prepend_metadata=graphrag_config_defaults.chunks.prepend_metadata,
|
||||
type=graphrag_config_defaults.chunking.type,
|
||||
size=graphrag_config_defaults.chunking.size,
|
||||
overlap=graphrag_config_defaults.chunking.overlap,
|
||||
encoding_model=graphrag_config_defaults.chunking.encoding_model,
|
||||
prepend_metadata=graphrag_config_defaults.chunking.prepend_metadata,
|
||||
),
|
||||
)
|
||||
"""The chunking configuration to use."""
|
||||
|
||||
@ -33,14 +33,14 @@ async def run_workflow(
|
||||
logger.info("Workflow started: create_base_text_units")
|
||||
documents = await load_table_from_storage("documents", context.output_storage)
|
||||
|
||||
tokenizer = get_tokenizer(encoding_model=config.chunks.encoding_model)
|
||||
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
|
||||
tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
|
||||
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
|
||||
output = create_base_text_units(
|
||||
documents,
|
||||
context.callbacks,
|
||||
tokenizer=tokenizer,
|
||||
chunker=chunker,
|
||||
prepend_metadata=config.chunks.prepend_metadata,
|
||||
prepend_metadata=config.chunking.prepend_metadata,
|
||||
)
|
||||
|
||||
await write_table_to_storage(output, "text_units", context.output_storage)
|
||||
|
||||
@ -62,7 +62,7 @@ async def load_docs_in_chunks(
|
||||
cache=NoopCache(),
|
||||
)
|
||||
tokenizer = get_tokenizer(embeddings_llm_settings)
|
||||
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
|
||||
chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
|
||||
input_storage = create_storage(config.input.storage)
|
||||
input_reader = InputReaderFactory().create(
|
||||
config.input.file_type,
|
||||
|
||||
@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
from graphrag_chunking.bootstrap_nltk import bootstrap
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkStrategyType
|
||||
from graphrag_chunking.chunk_strategy_type import ChunkerType
|
||||
from graphrag_chunking.chunker_factory import create_chunker
|
||||
from graphrag_chunking.chunking_config import ChunkingConfig
|
||||
from graphrag_chunking.token_chunker import (
|
||||
@ -29,7 +29,7 @@ class TestRunSentences:
|
||||
def test_basic_functionality(self):
|
||||
"""Test basic sentence splitting without metadata"""
|
||||
input = "This is a test. Another sentence. And a third one!"
|
||||
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
|
||||
chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
|
||||
chunks = chunker.chunk(input)
|
||||
|
||||
assert len(chunks) == 3
|
||||
@ -51,7 +51,7 @@ class TestRunSentences:
|
||||
def test_mixed_whitespace_handling(self):
|
||||
"""Test input with irregular whitespace"""
|
||||
input = " Sentence with spaces. Another one! "
|
||||
chunker = create_chunker(ChunkingConfig(strategy=ChunkStrategyType.Sentence))
|
||||
chunker = create_chunker(ChunkingConfig(type=ChunkerType.Sentence))
|
||||
chunks = chunker.chunk(input)
|
||||
|
||||
assert len(chunks) == 2
|
||||
@ -80,7 +80,7 @@ class TestRunTokens:
|
||||
size=5,
|
||||
overlap=1,
|
||||
encoding_model="fake-encoding",
|
||||
strategy=ChunkStrategyType.Tokens,
|
||||
type=ChunkerType.Tokens,
|
||||
)
|
||||
|
||||
chunker = create_chunker(config, mock_encoder.encode, mock_encoder.decode)
|
||||
|
||||
@ -164,7 +164,7 @@ def assert_text_embedding_configs(
|
||||
def assert_chunking_configs(actual: ChunkingConfig, expected: ChunkingConfig) -> None:
|
||||
assert actual.size == expected.size
|
||||
assert actual.overlap == expected.overlap
|
||||
assert actual.strategy == expected.strategy
|
||||
assert actual.type == expected.type
|
||||
assert actual.encoding_model == expected.encoding_model
|
||||
assert actual.prepend_metadata == expected.prepend_metadata
|
||||
|
||||
@ -344,7 +344,7 @@ def assert_graphrag_configs(actual: GraphRagConfig, expected: GraphRagConfig) ->
|
||||
assert_cache_configs(actual.cache, expected.cache)
|
||||
assert_input_configs(actual.input, expected.input)
|
||||
assert_text_embedding_configs(actual.embed_text, expected.embed_text)
|
||||
assert_chunking_configs(actual.chunks, expected.chunks)
|
||||
assert_chunking_configs(actual.chunking, expected.chunking)
|
||||
assert_snapshots_configs(actual.snapshots, expected.snapshots)
|
||||
assert_extract_graph_configs(actual.extract_graph, expected.extract_graph)
|
||||
assert_extract_graph_nlp_configs(
|
||||
|
||||
@ -35,7 +35,7 @@ async def test_create_base_text_units_metadata():
|
||||
|
||||
config = GraphRagConfig(models=DEFAULT_MODEL_CONFIG) # type: ignore
|
||||
config.input.metadata = ["title"]
|
||||
config.chunks.prepend_metadata = True
|
||||
config.chunking.prepend_metadata = True
|
||||
|
||||
await update_document_metadata(config.input.metadata, context)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user