mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Restore create_base_text_units parameterization
This commit is contained in:
parent
a20dbdb795
commit
e5c1aa7d52
@ -10,6 +10,7 @@ from graphrag_common.factory.factory import Factory, ServiceScope
|
||||
from graphrag.chunking.chunker import Chunker
|
||||
from graphrag.config.enums import ChunkStrategyType
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class ChunkerFactory(Factory[Chunker]):
|
||||
@ -36,7 +37,9 @@ def register_chunker(
|
||||
chunker_factory.register(chunker_type, chunker_initializer, scope)
|
||||
|
||||
|
||||
def create_chunker(config: ChunkingConfig) -> Chunker:
|
||||
def create_chunker(
|
||||
config: ChunkingConfig, tokenizer: Tokenizer | None = None
|
||||
) -> Chunker:
|
||||
"""Create a chunker implementation based on the given configuration.
|
||||
|
||||
Args
|
||||
@ -50,6 +53,8 @@ def create_chunker(config: ChunkingConfig) -> Chunker:
|
||||
The created chunker implementation.
|
||||
"""
|
||||
config_model = config.model_dump()
|
||||
if tokenizer is not None:
|
||||
config_model["tokenizer"] = tokenizer
|
||||
chunker_strategy = config.strategy
|
||||
|
||||
if chunker_strategy not in chunker_factory:
|
||||
@ -57,11 +62,11 @@ def create_chunker(config: ChunkingConfig) -> Chunker:
|
||||
case ChunkStrategyType.tokens:
|
||||
from graphrag.chunking.token_chunker import TokenChunker
|
||||
|
||||
chunker_factory.register(ChunkStrategyType.tokens, TokenChunker)
|
||||
register_chunker(ChunkStrategyType.tokens, TokenChunker)
|
||||
case ChunkStrategyType.sentence:
|
||||
from graphrag.chunking.sentence_chunker import SentenceChunker
|
||||
|
||||
chunker_factory.register(ChunkStrategyType.sentence, SentenceChunker)
|
||||
register_chunker(ChunkStrategyType.sentence, SentenceChunker)
|
||||
case _:
|
||||
msg = f"ChunkingConfig.strategy '{chunker_strategy}' is not registered in the ChunkerFactory. Registered types: {', '.join(chunker_factory.keys())}."
|
||||
raise ValueError(msg)
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from graphrag.chunking.chunker import Chunker
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
EncodedText = list[int]
|
||||
DecodeFn = Callable[[EncodedText], str]
|
||||
@ -22,13 +22,14 @@ class TokenChunker(Chunker):
|
||||
size: int,
|
||||
overlap: int,
|
||||
encoding_model: str,
|
||||
tokenizer: Tokenizer,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a token chunker instance."""
|
||||
self._size = size
|
||||
self._overlap = overlap
|
||||
self._encoding_model = encoding_model
|
||||
self._tokenizer = get_tokenizer(encoding_model=encoding_model)
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
"""Chunk the text into token-based chunks."""
|
||||
|
||||
@ -19,6 +19,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.progress import progress_ticker
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -32,7 +33,18 @@ async def run_workflow(
|
||||
logger.info("Workflow started: create_base_text_units")
|
||||
documents = await load_table_from_storage("documents", context.output_storage)
|
||||
|
||||
output = create_base_text_units(documents, context.callbacks, config.chunks)
|
||||
chunking_config = config.chunks
|
||||
tokenizer = get_tokenizer(encoding_model=chunking_config.encoding_model)
|
||||
|
||||
output = create_base_text_units(
|
||||
documents,
|
||||
context.callbacks,
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=chunking_config.size,
|
||||
chunk_overlap=chunking_config.overlap,
|
||||
prepend_metadata=chunking_config.prepend_metadata,
|
||||
chunk_size_includes_metadata=chunking_config.chunk_size_includes_metadata,
|
||||
)
|
||||
|
||||
await write_table_to_storage(output, "text_units", context.output_storage)
|
||||
|
||||
@ -43,17 +55,15 @@ async def run_workflow(
|
||||
def create_base_text_units(
|
||||
documents: pd.DataFrame,
|
||||
callbacks: WorkflowCallbacks,
|
||||
chunks_config: ChunkingConfig,
|
||||
tokenizer: Tokenizer,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
prepend_metadata: bool,
|
||||
chunk_size_includes_metadata: bool,
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform base text_units."""
|
||||
documents.sort_values(by=["id"], ascending=[True], inplace=True)
|
||||
|
||||
size = chunks_config.size
|
||||
encoding_model = chunks_config.encoding_model
|
||||
prepend_metadata = chunks_config.prepend_metadata
|
||||
chunk_size_includes_metadata = chunks_config.chunk_size_includes_metadata
|
||||
|
||||
tokenizer = get_tokenizer(encoding_model=encoding_model)
|
||||
num_total = _get_num_total(documents, "text")
|
||||
tick = progress_ticker(callbacks.progress, num_total)
|
||||
|
||||
@ -74,12 +84,17 @@ def create_base_text_units(
|
||||
|
||||
if chunk_size_includes_metadata:
|
||||
metadata_tokens = len(tokenizer.encode(metadata_str))
|
||||
if metadata_tokens >= size:
|
||||
if metadata_tokens >= chunk_size:
|
||||
message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
|
||||
raise ValueError(message)
|
||||
|
||||
chunks_config.size = size - metadata_tokens
|
||||
chunker = create_chunker(chunks_config)
|
||||
chunker = create_chunker(
|
||||
ChunkingConfig(
|
||||
size=chunk_size - metadata_tokens,
|
||||
overlap=chunk_overlap,
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
chunked = _chunk_text(
|
||||
pd.DataFrame([row]).reset_index(drop=True),
|
||||
|
||||
@ -70,7 +70,11 @@ async def load_docs_in_chunks(
|
||||
chunks_df = create_base_text_units(
|
||||
documents=dataset,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
chunks_config=config.chunks,
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=config.chunks.size,
|
||||
chunk_overlap=config.chunks.overlap,
|
||||
prepend_metadata=config.chunks.prepend_metadata,
|
||||
chunk_size_includes_metadata=config.chunks.chunk_size_includes_metadata,
|
||||
)
|
||||
|
||||
# Depending on the select method, build the dataset
|
||||
|
||||
@ -13,6 +13,16 @@ from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
def encode(self, text):
|
||||
return [ord(char) for char in text]
|
||||
|
||||
def decode(self, token_ids):
|
||||
return "".join(chr(id) for id in token_ids)
|
||||
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
class TestRunSentences:
|
||||
def setup_method(self, method):
|
||||
bootstrap()
|
||||
@ -59,23 +69,11 @@ class TestRunTokens:
|
||||
strategy=ChunkStrategyType.tokens,
|
||||
)
|
||||
|
||||
chunker = create_chunker(config)
|
||||
chunker = create_chunker(config, tokenizer=tokenizer)
|
||||
chunks = chunker.chunk(input)
|
||||
|
||||
assert len(chunks) > 0
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
def encode(self, text):
|
||||
return [ord(char) for char in text]
|
||||
|
||||
def decode(self, token_ids):
|
||||
return "".join(chr(id) for id in token_ids)
|
||||
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
|
||||
def test_split_text_str_empty():
|
||||
result = split_text_on_tokens(
|
||||
"",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user