Mirror of https://github.com/microsoft/graphrag.git, synced 2026-01-14 00:57:23 +08:00
Collapse token splitting functionality into one class/function
This commit is contained in:
parent
b63f747d44
commit
a20dbdb795
@@ -3,14 +3,16 @@
 
 """A module containing 'TokenChunker' class."""
 
+from collections.abc import Callable
 from typing import Any
 
 from graphrag.chunking.chunker import Chunker
-from graphrag.chunking.token_text_splitter import (
-    TokenTextSplitter,
-)
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
 
+EncodedText = list[int]
+DecodeFn = Callable[[EncodedText], str]
+EncodeFn = Callable[[str], EncodedText]
+
 
 class TokenChunker(Chunker):
     """A chunker that splits text into token-based chunks."""
@@ -26,12 +28,41 @@ class TokenChunker(Chunker):
         self._size = size
         self._overlap = overlap
         self._encoding_model = encoding_model
-        self._text_splitter = TokenTextSplitter(
-            chunk_size=size,
-            chunk_overlap=overlap,
-            tokenizer=get_tokenizer(encoding_model=encoding_model),
-        )
+        self._tokenizer = get_tokenizer(encoding_model=encoding_model)
 
     def chunk(self, text: str) -> list[str]:
         """Chunk the text into token-based chunks."""
-        return self._text_splitter.split_text(text)
+        return split_text_on_tokens(
+            text,
+            chunk_size=self._size,
+            chunk_overlap=self._overlap,
+            encode=self._tokenizer.encode,
+            decode=self._tokenizer.decode,
+        )
+
+
+def split_text_on_tokens(
+    text: str,
+    chunk_size: int,
+    chunk_overlap: int,
+    encode: EncodeFn,
+    decode: DecodeFn,
+) -> list[str]:
+    """Split a single text and return chunks using the tokenizer."""
+    result = []
+    input_tokens = encode(text)
+
+    start_idx = 0
+    cur_idx = min(start_idx + chunk_size, len(input_tokens))
+    chunk_tokens = input_tokens[start_idx:cur_idx]
+
+    while start_idx < len(input_tokens):
+        chunk_text = decode(list(chunk_tokens))
+        result.append(chunk_text)  # Append chunked text as string
+        if cur_idx == len(input_tokens):
+            break
+        start_idx += chunk_size - chunk_overlap
+        cur_idx = min(start_idx + chunk_size, len(input_tokens))
+        chunk_tokens = input_tokens[start_idx:cur_idx]
+
+    return result
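The consolidated helper is a plain sliding window over token ids: each window is chunk_size tokens wide and successive window starts advance by chunk_size - chunk_overlap. A minimal usage sketch (the ord/chr codec below is a stand-in for a real tokenizer, mirroring the MockTokenizer used in the tests later in this commit):

    from graphrag.chunking.token_chunker import split_text_on_tokens

    def encode(s: str) -> list[int]:        # stand-in for tokenizer.encode
        return [ord(c) for c in s]

    def decode(ids: list[int]) -> str:      # stand-in for tokenizer.decode
        return "".join(chr(i) for i in ids)

    chunks = split_text_on_tokens(
        "abcdefghij",
        chunk_size=4,
        chunk_overlap=2,
        encode=encode,
        decode=decode,
    )
    # Window starts advance by 4 - 2 = 2 tokens:
    # chunks == ["abcd", "cdef", "efgh", "ghij"]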
@@ -1,87 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'TokenTextSplitter' class and 'split_text_on_tokens' function."""
-
-import logging
-from abc import ABC
-from collections.abc import Callable
-from typing import cast
-
-import pandas as pd
-
-from graphrag.tokenizer.tokenizer import Tokenizer
-
-EncodedText = list[int]
-DecodeFn = Callable[[EncodedText], str]
-EncodeFn = Callable[[str], EncodedText]
-LengthFn = Callable[[str], int]
-
-logger = logging.getLogger(__name__)
-
-
-class TokenTextSplitter(ABC):
-    """Token text splitter class definition."""
-
-    _chunk_size: int
-    _chunk_overlap: int
-
-    def __init__(
-        self,
-        tokenizer: Tokenizer,
-        # based on OpenAI embedding chunk size limits
-        # https://devblogs.microsoft.com/azure-sql/embedding-models-and-dimensions-optimizing-the-performance-resource-usage-ratio/
-        chunk_size: int = 8191,
-        chunk_overlap: int = 100,
-    ):
-        """Init method definition."""
-        self._chunk_size = chunk_size
-        self._chunk_overlap = chunk_overlap
-        self._tokenizer = tokenizer
-
-    def num_tokens(self, text: str) -> int:
-        """Return the number of tokens in a string."""
-        return self._tokenizer.num_tokens(text)
-
-    def split_text(self, text: str) -> list[str]:
-        """Split text method."""
-        if cast("bool", pd.isna(text)) or text == "":
-            return []
-        if not isinstance(text, str):
-            msg = f"Attempting to split a non-string value, actual is {type(text)}"
-            raise TypeError(msg)
-
-        return split_text_on_tokens(
-            text,
-            chunk_overlap=self._chunk_overlap,
-            tokens_per_chunk=self._chunk_size,
-            decode=self._tokenizer.decode,
-            encode=self._tokenizer.encode,
-        )
-
-
-def split_text_on_tokens(
-    text: str,
-    tokens_per_chunk: int,
-    chunk_overlap: int,
-    encode: EncodeFn,
-    decode: DecodeFn,
-) -> list[str]:
-    """Split a single text and return chunks using the tokenizer."""
-    result = []
-    input_tokens = encode(text)
-
-    start_idx = 0
-    cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
-    chunk_tokens = input_tokens[start_idx:cur_idx]
-
-    while start_idx < len(input_tokens):
-        chunk_text = decode(list(chunk_tokens))
-        result.append(chunk_text)  # Append chunked text as string
-        if cur_idx == len(input_tokens):
-            break
-        start_idx += tokens_per_chunk - chunk_overlap
-        cur_idx = min(start_idx + tokens_per_chunk, len(input_tokens))
-        chunk_tokens = input_tokens[start_idx:cur_idx]
-
-    return result
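One behavioral note: the deleted class short-circuited NaN input to [] and raised TypeError on non-strings before splitting, while the new free function only handles the empty string naturally (encode("") yields no tokens, so the loop never runs). Callers that may receive null-ish values would need their own guard; a hypothetical wrapper reproducing the old checks:

    import pandas as pd

    def split_text_guarded(text, **kwargs) -> list[str]:
        # Hypothetical helper, not part of this commit: mirrors the
        # deleted TokenTextSplitter.split_text input validation.
        if pd.isna(text) or text == "":
            return []
        if not isinstance(text, str):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)
        return split_text_on_tokens(text, **kwargs)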
@@ -10,7 +10,7 @@ from dataclasses import dataclass
 import numpy as np
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.chunking.token_text_splitter import TokenTextSplitter
+from graphrag.chunking.token_chunker import split_text_on_tokens
 from graphrag.index.utils.is_null import is_null
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.progress import ProgressTicker, progress_ticker
@@ -39,17 +39,15 @@ async def run_embed_text(
     if is_null(input):
         return TextEmbeddingResult(embeddings=None)
 
-    splitter = _get_splitter(tokenizer, batch_max_tokens)
-
     semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads)
 
     # Break up the input texts. The sizes here indicate how many snippets are in each input text
-    texts, input_sizes = _prepare_embed_texts(input, splitter)
+    texts, input_sizes = _prepare_embed_texts(input, tokenizer, batch_max_tokens)
     text_batches = _create_text_batches(
         texts,
+        tokenizer,
         batch_size,
         batch_max_tokens,
-        splitter,
     )
     logger.info(
         "embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
@@ -72,13 +70,6 @@ async def run_embed_text(
     return TextEmbeddingResult(embeddings=embeddings)
 
 
-def _get_splitter(tokenizer: Tokenizer, batch_max_tokens: int) -> TokenTextSplitter:
-    return TokenTextSplitter(
-        tokenizer=tokenizer,
-        chunk_size=batch_max_tokens,
-    )
-
-
 async def _execute(
     model: EmbeddingModel,
     chunks: list[list[str]],
@@ -100,9 +91,9 @@ async def _execute(
 
 def _create_text_batches(
     texts: list[str],
+    tokenizer: Tokenizer,
     max_batch_size: int,
     max_batch_tokens: int,
-    splitter: TokenTextSplitter,
 ) -> list[list[str]]:
     """Create batches of texts to embed."""
     # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
@@ -112,7 +103,7 @@ def _create_text_batches(
     current_batch_tokens = 0
 
     for text in texts:
-        token_count = splitter.num_tokens(text)
+        token_count = tokenizer.num_tokens(text)
         if (
             len(current_batch) >= max_batch_size
            or current_batch_tokens + token_count > max_batch_tokens
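The batching rule flushes the current batch whenever adding the next text would exceed either cap. A worked trace under hypothetical caps:

    # Hypothetical caps: max_batch_size=2, max_batch_tokens=10
    # token counts:      t0=6, t1=5, t2=3
    #
    # t0: empty batch, 0 + 6 <= 10, fits   -> batch = [t0], tokens = 6
    # t1: 6 + 5 > 10, flush                -> batches = [[t0]], batch = [t1], tokens = 5
    # t2: len 1 < 2 and 5 + 3 <= 10, fits  -> batch = [t1, t2]
    # result: [[t0], [t1, t2]]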
@@ -131,18 +122,23 @@
 
 
 def _prepare_embed_texts(
-    input: list[str], splitter: TokenTextSplitter
+    input: list[str],
+    tokenizer: Tokenizer,
+    batch_max_tokens: int = 8191,
+    chunk_overlap: int = 100,
 ) -> tuple[list[str], list[int]]:
     sizes: list[int] = []
     snippets: list[str] = []
 
     for text in input:
         # Split the input text and filter out any empty content
-        split_texts = splitter.split_text(text)
-        if split_texts is None:
-            continue
+        split_texts = split_text_on_tokens(
+            text,
+            chunk_size=batch_max_tokens,
+            chunk_overlap=chunk_overlap,
+            encode=tokenizer.encode,
+            decode=tokenizer.decode,
+        )
         split_texts = [text for text in split_texts if len(text) > 0]
 
         sizes.append(len(split_texts))
         snippets.extend(split_texts)
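Each input document is split into snippets no larger than batch_max_tokens, and the per-document snippet counts are recorded so the flat embedding results can be regrouped per document afterwards. A toy run of the same loop (character codec again, small sizes in place of the 8191/100 defaults):

    docs = ["abcdef", "xy"]
    sizes, snippets = [], []
    for doc in docs:
        parts = split_text_on_tokens(
            doc,
            chunk_size=4,                               # illustrative, not the default
            chunk_overlap=0,
            encode=lambda s: [ord(c) for c in s],       # stand-in tokenizer
            decode=lambda ids: "".join(chr(i) for i in ids),
        )
        sizes.append(len(parts))
        snippets.extend(parts)
    # snippets == ["abcd", "ef", "xy"], sizes == [2, 1]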
@@ -5,6 +5,9 @@ from unittest.mock import Mock, patch
 
 from graphrag.chunking.bootstrap import bootstrap
 from graphrag.chunking.chunker_factory import create_chunker
+from graphrag.chunking.token_chunker import (
+    split_text_on_tokens,
+)
 from graphrag.config.enums import ChunkStrategyType
 from graphrag.config.models.chunking_config import ChunkingConfig
 from graphrag.tokenizer.get_tokenizer import get_tokenizer
@@ -62,6 +65,89 @@ class TestRunTokens:
         assert len(chunks) > 0
 
 
+class MockTokenizer:
+    def encode(self, text):
+        return [ord(char) for char in text]
+
+    def decode(self, token_ids):
+        return "".join(chr(id) for id in token_ids)
+
+
+tokenizer = get_tokenizer()
+
+
+def test_split_text_str_empty():
+    result = split_text_on_tokens(
+        "",
+        chunk_size=5,
+        chunk_overlap=2,
+        encode=tokenizer.encode,
+        decode=tokenizer.decode,
+    )
+    assert result == []
+
+
+def test_split_text_on_tokens():
+    text = "This is a test text, meaning to be taken seriously by this test only."
+    mocked_tokenizer = MockTokenizer()
+    expected_splits = [
+        "This is a ",
+        "is a test ",
+        "test text,",
+        "text, mean",
+        " meaning t",
+        "ing to be ",
+        "o be taken",
+        "taken seri",  # cspell:disable-line
+        " seriously",
+        "ously by t",  # cspell:disable-line
+        " by this t",
+        "his test o",
+        "est only.",
+    ]
+
+    result = split_text_on_tokens(
+        text=text,
+        chunk_overlap=5,
+        chunk_size=10,
+        decode=mocked_tokenizer.decode,
+        encode=lambda text: mocked_tokenizer.encode(text),
+    )
+    assert result == expected_splits
+
+
+def test_split_text_on_tokens_no_overlap():
+    text = "This is a test text, meaning to be taken seriously by this test only."
+    tok = get_tokenizer(encoding_model="cl100k_base")
+
+    expected_splits = [
+        "This is",
+        " is a",
+        " a test",
+        " test text",
+        " text,",
+        ", meaning",
+        " meaning to",
+        " to be",
+        " be taken",  # cspell:disable-line
+        " taken seriously",  # cspell:disable-line
+        " seriously by",
+        " by this",  # cspell:disable-line
+        " this test",
+        " test only",
+        " only.",
+    ]
+
+    result = split_text_on_tokens(
+        text=text,
+        chunk_size=2,
+        chunk_overlap=1,
+        decode=tok.decode,
+        encode=tok.encode,
+    )
+    assert result == expected_splits
+
+
 @patch("tiktoken.get_encoding")
 def test_get_encoding_fn_encode(mock_get_encoding):
     # Create a mock encoding object with encode and decode methods
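A quick sanity check on the character-level case: the test sentence encodes to 69 one-character tokens under MockTokenizer, and with chunk_size=10 and chunk_overlap=5 each window start advances by 5 tokens, so the expected list should hold ceil((69 - 10) / 5) + 1 = 13 chunks:

    import math

    text = "This is a test text, meaning to be taken seriously by this test only."
    n, size, overlap = len(text), 10, 5   # 69 one-char tokens
    step = size - overlap
    n_chunks = 1 if n <= size else math.ceil((n - size) / step) + 1
    assert n_chunks == 13                 # matches len(expected_splits) above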
@@ -1,124 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-from unittest import mock
-
-import pytest
-import tiktoken
-from graphrag.chunking.token_text_splitter import (
-    TokenTextSplitter,
-    split_text_on_tokens,
-)
-from graphrag.tokenizer.get_tokenizer import get_tokenizer
-
-
-class MockTokenizer:
-    def encode(self, text):
-        return [ord(char) for char in text]
-
-    def decode(self, token_ids):
-        return "".join(chr(id) for id in token_ids)
-
-
-tokenizer = get_tokenizer()
-
-
-def test_split_text_str_empty():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    result = splitter.split_text("")
-
-    assert result == []
-
-
-def test_split_text_str_bool():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    result = splitter.split_text(None)  # type: ignore
-
-    assert result == []
-
-
-def test_split_text_str_int():
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-    with pytest.raises(TypeError):
-        splitter.split_text(123)  # type: ignore
-
-
-@mock.patch("graphrag.chunking.token_text_splitter.split_text_on_tokens")
-def test_split_text_large_input(mock_split):
-    large_text = "a" * 10_000
-    mock_split.return_value = ["chunk"] * 2_000
-    splitter = TokenTextSplitter(tokenizer=tokenizer, chunk_size=5, chunk_overlap=2)
-
-    result = splitter.split_text(large_text)
-
-    assert len(result) == 2_000, "Large input was not split correctly"
-    mock_split.assert_called_once()
-
-
-def test_split_text_on_tokens():
-    text = "This is a test text, meaning to be taken seriously by this test only."
-    mocked_tokenizer = MockTokenizer()
-    expected_splits = [
-        "This is a ",
-        "is a test ",
-        "test text,",
-        "text, mean",
-        " meaning t",
-        "ing to be ",
-        "o be taken",
-        "taken seri",  # cspell:disable-line
-        " seriously",
-        "ously by t",  # cspell:disable-line
-        " by this t",
-        "his test o",
-        "est only.",
-    ]
-
-    result = split_text_on_tokens(
-        text=text,
-        chunk_overlap=5,
-        tokens_per_chunk=10,
-        decode=mocked_tokenizer.decode,
-        encode=lambda text: mocked_tokenizer.encode(text),
-    )
-    assert result == expected_splits
-
-
-def test_split_text_on_tokens_no_overlap():
-    text = "This is a test text, meaning to be taken seriously by this test only."
-    enc = tiktoken.get_encoding("cl100k_base")
-
-    def encode(text: str) -> list[int]:
-        if not isinstance(text, str):
-            text = f"{text}"
-        return enc.encode(text)
-
-    def decode(tokens: list[int]) -> str:
-        return enc.decode(tokens)
-
-    expected_splits = [
-        "This is",
-        " is a",
-        " a test",
-        " test text",
-        " text,",
-        ", meaning",
-        " meaning to",
-        " to be",
-        " be taken",  # cspell:disable-line
-        " taken seriously",  # cspell:disable-line
-        " seriously by",
-        " by this",  # cspell:disable-line
-        " this test",
-        " test only",
-        " only.",
-    ]
-
-    result = split_text_on_tokens(
-        text=text,
-        chunk_overlap=1,
-        tokens_per_chunk=2,
-        decode=decode,
-        encode=lambda text: encode(text),
-    )
-    assert result == expected_splits