Restore create_base_text_units parameterization

Nathan Evans 2025-12-18 13:46:22 -08:00
parent a20dbdb795
commit e5c1aa7d52
5 changed files with 53 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ from graphrag_common.factory.factory import Factory, ServiceScope
from graphrag.chunking.chunker import Chunker
from graphrag.config.enums import ChunkStrategyType
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.tokenizer.tokenizer import Tokenizer
class ChunkerFactory(Factory[Chunker]):
@@ -36,7 +37,9 @@ def register_chunker(
chunker_factory.register(chunker_type, chunker_initializer, scope)
def create_chunker(config: ChunkingConfig) -> Chunker:
def create_chunker(
config: ChunkingConfig, tokenizer: Tokenizer | None = None
) -> Chunker:
"""Create a chunker implementation based on the given configuration.
Args
@@ -50,6 +53,8 @@ def create_chunker(config: ChunkingConfig) -> Chunker:
The created chunker implementation.
"""
config_model = config.model_dump()
if tokenizer is not None:
config_model["tokenizer"] = tokenizer
chunker_strategy = config.strategy
if chunker_strategy not in chunker_factory:
@@ -57,11 +62,11 @@ def create_chunker(config: ChunkingConfig) -> Chunker:
case ChunkStrategyType.tokens:
from graphrag.chunking.token_chunker import TokenChunker
chunker_factory.register(ChunkStrategyType.tokens, TokenChunker)
register_chunker(ChunkStrategyType.tokens, TokenChunker)
case ChunkStrategyType.sentence:
from graphrag.chunking.sentence_chunker import SentenceChunker
chunker_factory.register(ChunkStrategyType.sentence, SentenceChunker)
register_chunker(ChunkStrategyType.sentence, SentenceChunker)
case _:
msg = f"ChunkingConfig.strategy '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
raise ValueError(msg)
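
With this change, create_chunker accepts an optional Tokenizer and, when one is supplied, injects it into the kwargs handed to the chunker initializer. A minimal usage sketch, assuming create_chunker is imported from the factory module shown above (its exact module path is not visible in this diff) and using an illustrative encoding name:

from graphrag.config.enums import ChunkStrategyType
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# Size, overlap, and strategy mirror the fields used elsewhere in this commit.
config = ChunkingConfig(size=1200, overlap=100, strategy=ChunkStrategyType.tokens)

# The tokenizer is optional; when provided it lands in config_model["tokenizer"].
tokenizer = get_tokenizer(encoding_model=config.encoding_model)
chunker = create_chunker(config, tokenizer=tokenizer)
chunks = chunker.chunk("some long document text ...")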

View File

@@ -7,7 +7,7 @@ from collections.abc import Callable
from typing import Any
from graphrag.chunking.chunker import Chunker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@@ -22,13 +22,14 @@ class TokenChunker(Chunker):
size: int,
overlap: int,
encoding_model: str,
tokenizer: Tokenizer,
**kwargs: Any,
) -> None:
"""Create a token chunker instance."""
self._size = size
self._overlap = overlap
self._encoding_model = encoding_model
self._tokenizer = get_tokenizer(encoding_model=encoding_model)
self._tokenizer = tokenizer
def chunk(self, text: str) -> list[str]:
"""Chunk the text into token-based chunks."""

View File

@@ -19,6 +19,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@@ -32,7 +33,18 @@ async def run_workflow(
logger.info("Workflow started: create_base_text_units")
documents = await load_table_from_storage("documents", context.output_storage)
output = create_base_text_units(documents, context.callbacks, config.chunks)
chunking_config = config.chunks
tokenizer = get_tokenizer(encoding_model=chunking_config.encoding_model)
output = create_base_text_units(
documents,
context.callbacks,
tokenizer=tokenizer,
chunk_size=chunking_config.size,
chunk_overlap=chunking_config.overlap,
prepend_metadata=chunking_config.prepend_metadata,
chunk_size_includes_metadata=chunking_config.chunk_size_includes_metadata,
)
await write_table_to_storage(output, "text_units", context.output_storage)
@@ -43,17 +55,15 @@ async def run_workflow(
def create_base_text_units(
documents: pd.DataFrame,
callbacks: WorkflowCallbacks,
chunks_config: ChunkingConfig,
tokenizer: Tokenizer,
chunk_size: int,
chunk_overlap: int,
prepend_metadata: bool,
chunk_size_includes_metadata: bool,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
documents.sort_values(by=["id"], ascending=[True], inplace=True)
size = chunks_config.size
encoding_model = chunks_config.encoding_model
prepend_metadata = chunks_config.prepend_metadata
chunk_size_includes_metadata = chunks_config.chunk_size_includes_metadata
tokenizer = get_tokenizer(encoding_model=encoding_model)
num_total = _get_num_total(documents, "text")
tick = progress_ticker(callbacks.progress, num_total)
@@ -74,12 +84,17 @@
if chunk_size_includes_metadata:
metadata_tokens = len(tokenizer.encode(metadata_str))
if metadata_tokens >= size:
if metadata_tokens >= chunk_size:
message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
raise ValueError(message)
chunks_config.size = size - metadata_tokens
chunker = create_chunker(chunks_config)
chunker = create_chunker(
ChunkingConfig(
size=chunk_size - metadata_tokens,
overlap=chunk_overlap,
),
tokenizer=tokenizer,
)
chunked = _chunk_text(
pd.DataFrame([row]).reset_index(drop=True),
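
The workflow now passes explicit chunking parameters rather than mutating a shared ChunkingConfig. A standalone sketch of calling create_base_text_units directly, assuming a small in-memory documents frame (the "id" and "text" columns follow the sort and count calls above; the NoopWorkflowCallbacks import path is assumed, not shown in this diff):

import pandas as pd

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks  # assumed path
from graphrag.tokenizer.get_tokenizer import get_tokenizer

docs = pd.DataFrame({"id": ["d1"], "text": ["hello world " * 500]})
tokenizer = get_tokenizer(encoding_model="cl100k_base")
text_units = create_base_text_units(
    docs,
    NoopWorkflowCallbacks(),
    tokenizer=tokenizer,
    chunk_size=600,
    chunk_overlap=100,
    prepend_metadata=False,
    chunk_size_includes_metadata=False,
)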

View File

@@ -70,7 +70,11 @@ async def load_docs_in_chunks(
chunks_df = create_base_text_units(
documents=dataset,
callbacks=NoopWorkflowCallbacks(),
chunks_config=config.chunks,
tokenizer=tokenizer,
chunk_size=config.chunks.size,
chunk_overlap=config.chunks.overlap,
prepend_metadata=config.chunks.prepend_metadata,
chunk_size_includes_metadata=config.chunks.chunk_size_includes_metadata,
)
# Depending on the select method, build the dataset
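
The tokenizer passed here is created outside the visible hunk; following the same pattern as the workflow change above, it would presumably be built once per run, for example:

tokenizer = get_tokenizer(encoding_model=config.chunks.encoding_model)  # assumed; its creation is not shown in this hunk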

View File

@@ -13,6 +13,16 @@ from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.tokenizer.get_tokenizer import get_tokenizer
class MockTokenizer:
def encode(self, text):
return [ord(char) for char in text]
def decode(self, token_ids):
return "".join(chr(id) for id in token_ids)
tokenizer = get_tokenizer()
class TestRunSentences:
def setup_method(self, method):
bootstrap()
@@ -59,23 +69,11 @@ class TestRunTokens:
strategy=ChunkStrategyType.tokens,
)
chunker = create_chunker(config)
chunker = create_chunker(config, tokenizer=tokenizer)
chunks = chunker.chunk(input)
assert len(chunks) > 0
class MockTokenizer:
def encode(self, text):
return [ord(char) for char in text]
def decode(self, token_ids):
return "".join(chr(id) for id in token_ids)
tokenizer = get_tokenizer()
def test_split_text_str_empty():
result = split_text_on_tokens(
"",