graphrag/graphrag/index/text_splitting/text_splitting.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""

import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from dataclasses import dataclass
from typing import Any, Literal, cast

import pandas as pd
import tiktoken

import graphrag.config.defaults as defs
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker

EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
EncodeFn = Callable[[str], EncodedText]
LengthFn = Callable[[str], int]

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
tokens_per_chunk: int
"""Maximum number of tokens per chunk"""
decode: DecodeFn
""" Function to decode a list of token ids to a string"""
encode: EncodeFn
""" Function to encode a string to a list of token ids"""
class TextSplitter(ABC):
"""Text splitter class definition."""
_chunk_size: int
_chunk_overlap: int
_length_function: LengthFn
_keep_separator: bool
_add_start_index: bool
_strip_whitespace: bool
def __init__(
self,
        # based on text-embedding-ada-002 max input buffer length
        # https://platform.openai.com/docs/guides/embeddings/second-generation-models
chunk_size: int = 8191,
chunk_overlap: int = 100,
length_function: LengthFn = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
):
"""Init method definition."""
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str | list[str]) -> Iterable[str]:
        """Split text method definition."""


class NoopTextSplitter(TextSplitter):
"""Noop text splitter class definition."""
def split_text(self, text: str | list[str]) -> Iterable[str]:
"""Split text method definition."""
return [text] if isinstance(text, str) else text
class TokenTextSplitter(TextSplitter):
"""Token text splitter class definition."""
_allowed_special: Literal["all"] | set[str]
_disallowed_special: Literal["all"] | Collection[str]
def __init__(
self,
encoding_name: str = defs.ENCODING_MODEL,
model_name: str | None = None,
allowed_special: Literal["all"] | set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
"""Init method definition."""
super().__init__(**kwargs)
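        # Resolve the tokenizer from the model name when one is given, falling
        # back to the configured encoding when the model is unknown to tiktoken.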
if model_name is not None:
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
log.exception("Model %s not found, using %s", model_name, encoding_name)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
        self._allowed_special = allowed_special or set()
        self._disallowed_special = disallowed_special

    def encode(self, text: str) -> list[int]:
"""Encode the given text into an int-vector."""
return self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
        )

    def num_tokens(self, text: str) -> int:
"""Return the number of tokens in a string."""
        return len(self.encode(text))

    def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""
if isinstance(text, list):
text = " ".join(text)
elif cast("bool", pd.isna(text)) or text == "":
return []
if not isinstance(text, str):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=self.encode,
        )
        return split_single_text_on_tokens(text=text, tokenizer=tokenizer)


def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
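    # Slide a fixed-size window across the token ids; each step advances by
    # tokens_per_chunk - chunk_overlap, so adjacent chunks overlap.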
while start_idx < len(input_ids):
chunk_text = tokenizer.decode(list(chunk_ids))
result.append(chunk_text) # Append chunked text as string
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
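

# Illustrative arithmetic for the window above: with tokens_per_chunk=1200 and
# chunk_overlap=100, start_idx advances by 1100 tokens per iteration, so
# consecutive chunks share 100 tokens:
#   chunk 0 covers token indices [0, 1200)
#   chunk 1 covers token indices [1100, 2300)
#   chunk 2 covers token indices [2200, 3400), and so on.
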
# Adapted from https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# so that we have better control over the chunking process.
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []
mapped_ids = []
for source_doc_idx, text in enumerate(texts):
encoded = tokenizer.encode(text)
if tick:
tick(1) # Track progress if tick callback is provided
mapped_ids.append((source_doc_idx, encoded))
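    # Flatten all documents into a single token stream, tagging every token
    # with the index of its source document.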
input_ids = [
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
]
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
if cur_idx == len(input_ids):
break
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return result
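

if __name__ == "__main__":
    # Minimal usage sketch (illustrative values; requires tiktoken's encoding
    # data). Splits a long string into ~1200-token chunks with 100-token overlap.
    splitter = TokenTextSplitter(chunk_size=1200, chunk_overlap=100)
    sample = "GraphRAG splits long documents into overlapping token windows. " * 400
    for i, chunk in enumerate(splitter.split_text(sample)):
        print(f"chunk {i}: {splitter.num_tokens(chunk)} tokens")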