mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Co-locate chunking/splitting
This commit is contained in:
parent
461291706f
commit
b63f747d44
@ -6,8 +6,8 @@
|
||||
from typing import Any
|
||||
|
||||
from graphrag.chunking.chunker import Chunker
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
split_single_text_on_tokens,
|
||||
from graphrag.chunking.token_text_splitter import (
|
||||
TokenTextSplitter,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
@ -26,14 +26,12 @@ class TokenChunker(Chunker):
|
||||
self._size = size
|
||||
self._overlap = overlap
|
||||
self._encoding_model = encoding_model
|
||||
self._text_splitter = TokenTextSplitter(
|
||||
chunk_size=size,
|
||||
chunk_overlap=overlap,
|
||||
tokenizer=get_tokenizer(encoding_model=encoding_model),
|
||||
)
|
||||
|
||||
def chunk(self, text: str) -> list[str]:
|
||||
"""Chunk the text into token-based chunks."""
|
||||
tokenizer = get_tokenizer(encoding_model=self._encoding_model)
|
||||
return split_single_text_on_tokens(
|
||||
text,
|
||||
chunk_overlap=self._overlap,
|
||||
tokens_per_chunk=self._size,
|
||||
encode=tokenizer.encode,
|
||||
decode=tokenizer.decode,
|
||||
)
|
||||
return self._text_splitter.split_text(text)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing 'TokenTextSplitter' class and 'split_single_text_on_tokens' function."""
|
||||
"""A module containing 'TokenTextSplitter' class and 'split_text_on_tokens' function."""
|
||||
|
||||
import logging
|
||||
from abc import ABC
|
||||
@ -10,7 +10,6 @@ from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
EncodedText = list[int]
|
||||
@ -26,47 +25,33 @@ class TokenTextSplitter(ABC):
|
||||
|
||||
_chunk_size: int
|
||||
_chunk_overlap: int
|
||||
_length_function: LengthFn
|
||||
_keep_separator: bool
|
||||
_add_start_index: bool
|
||||
_strip_whitespace: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
# based on OpenAI embedding chunk size limits
|
||||
# https://devblogs.microsoft.com/azure-sql/embedding-models-and-dimensions-optimizing-the-performance-resource-usage-ratio/
|
||||
chunk_size: int = 8191,
|
||||
chunk_overlap: int = 100,
|
||||
length_function: LengthFn = len,
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
strip_whitespace: bool = True,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
self._keep_separator = keep_separator
|
||||
self._add_start_index = add_start_index
|
||||
self._strip_whitespace = strip_whitespace
|
||||
self._tokenizer = tokenizer or get_tokenizer()
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
def num_tokens(self, text: str) -> int:
|
||||
"""Return the number of tokens in a string."""
|
||||
return self._tokenizer.num_tokens(text)
|
||||
|
||||
def split_text(self, text: str | list[str]) -> list[str]:
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split text method."""
|
||||
if isinstance(text, list):
|
||||
text = " ".join(text)
|
||||
elif cast("bool", pd.isna(text)) or text == "":
|
||||
if cast("bool", pd.isna(text)) or text == "":
|
||||
return []
|
||||
if not isinstance(text, str):
|
||||
msg = f"Attempting to split a non-string value, actual is {type(text)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
return split_single_text_on_tokens(
|
||||
return split_text_on_tokens(
|
||||
text,
|
||||
chunk_overlap=self._chunk_overlap,
|
||||
tokens_per_chunk=self._chunk_size,
|
||||
@ -75,7 +60,7 @@ class TokenTextSplitter(ABC):
|
||||
)
|
||||
|
||||
|
||||
def split_single_text_on_tokens(
|
||||
def split_text_on_tokens(
|
||||
text: str,
|
||||
tokens_per_chunk: int,
|
||||
chunk_overlap: int,
|
||||
@ -84,19 +69,19 @@ def split_single_text_on_tokens(
|
||||
) -> list[str]:
|
||||
"""Split a single text and return chunks using the tokenizer."""
|
||||
result = []
|
||||
input_ids = encode(text)
|
||||
input_tokens = encode(text)
|
||||
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
|
||||
chunk_tokens = input_tokens[start_idx:cur_idx]
|
||||
|
||||
while start_idx < len(input_ids):
|
||||
chunk_text = decode(list(chunk_ids))
|
||||
while start_idx < len(input_tokens):
|
||||
chunk_text = decode(list(chunk_tokens))
|
||||
result.append(chunk_text) # Append chunked text as string
|
||||
if cur_idx == len(input_ids):
|
||||
if cur_idx == len(input_tokens):
|
||||
break
|
||||
start_idx += tokens_per_chunk - chunk_overlap
|
||||
cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
|
||||
chunk_tokens = input_tokens[start_idx:cur_idx]
|
||||
|
||||
return result
|
||||
@ -10,7 +10,7 @@ from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
|
||||
from graphrag.chunking.token_text_splitter import TokenTextSplitter
|
||||
from graphrag.index.utils.is_null import is_null
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.logger.progress import ProgressTicker, progress_ticker
|
||||
|
||||
@ -5,10 +5,11 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
import tiktoken
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
from graphrag.chunking.token_text_splitter import (
|
||||
TokenTextSplitter,
|
||||
split_single_text_on_tokens,
|
||||
split_text_on_tokens,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
|
||||
class MockTokenizer:
|
||||
@ -19,31 +20,34 @@ class MockTokenizer:
|
||||
return "".join(chr(id) for id in token_ids)
|
||||
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
|
||||
def test_split_text_str_empty():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
|
||||
result = splitter.split_text("")
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_split_text_str_bool():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
|
||||
result = splitter.split_text(None) # type: ignore
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_split_text_str_int():
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
|
||||
with pytest.raises(TypeError):
|
||||
splitter.split_text(123) # type: ignore
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
|
||||
@mock.patch("graphrag.chunking.token_text_splitter.split_text_on_tokens")
|
||||
def test_split_text_large_input(mock_split):
|
||||
large_text = "a" * 10_000
|
||||
mock_split.return_value = ["chunk"] * 2_000
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
|
||||
|
||||
result = splitter.split_text(large_text)
|
||||
|
||||
@ -51,20 +55,7 @@ def test_split_text_large_input(mock_split):
|
||||
mock_split.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("graphrag.index.text_splitting.text_splitting.split_single_text_on_tokens")
|
||||
def test_token_text_splitter(mock_split_text):
|
||||
expected_chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
mock_split_text.return_value = expected_chunks
|
||||
|
||||
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=2)
|
||||
|
||||
splitter.split_text(["chunk1", "chunk2", "chunk3"])
|
||||
|
||||
mock_split_text.assert_called_once()
|
||||
|
||||
|
||||
def test_split_single_text_on_tokens():
|
||||
def test_split_text_on_tokens():
|
||||
text = "This is a test text, meaning to be taken seriously by this test only."
|
||||
mocked_tokenizer = MockTokenizer()
|
||||
expected_splits = [
|
||||
@ -83,7 +74,7 @@ def test_split_single_text_on_tokens():
|
||||
"est only.",
|
||||
]
|
||||
|
||||
result = split_single_text_on_tokens(
|
||||
result = split_text_on_tokens(
|
||||
text=text,
|
||||
chunk_overlap=5,
|
||||
tokens_per_chunk=10,
|
||||
@ -93,7 +84,7 @@ def test_split_single_text_on_tokens():
|
||||
assert result == expected_splits
|
||||
|
||||
|
||||
def test_split_single_text_on_tokens_no_overlap():
|
||||
def test_split_text_on_tokens_no_overlap():
|
||||
text = "This is a test text, meaning to be taken seriously by this test only."
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
@ -123,7 +114,7 @@ def test_split_single_text_on_tokens_no_overlap():
|
||||
" only.",
|
||||
]
|
||||
|
||||
result = split_single_text_on_tokens(
|
||||
result = split_text_on_tokens(
|
||||
text=text,
|
||||
chunk_overlap=1,
|
||||
tokens_per_chunk=2,
|
||||
@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
@ -1,2 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
Loading…
Reference in New Issue
Block a user