Move Tokenizer base class to common package

This commit is contained in:
Nathan Evans 2025-12-18 13:56:32 -08:00
parent e5c1aa7d52
commit b7c06730d7
41 changed files with 158 additions and 39 deletions

View File

@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The GraphRAG types module: contains commonly-reused types across the monorepo packages."""
from graphrag_common.types.tokenizer import Tokenizer
__all__ = ["Tokenizer"]

View File

@ -6,11 +6,11 @@
from collections.abc import Callable
from graphrag_common.factory.factory import Factory, ServiceScope
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.chunking.chunker import Chunker
from graphrag.config.enums import ChunkStrategyType
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.tokenizer.tokenizer import Tokenizer
class ChunkerFactory(Factory[Chunker]):

View File

@ -6,8 +6,9 @@
from collections.abc import Callable
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.chunking.chunker import Chunker
from graphrag.tokenizer.tokenizer import Tokenizer
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]

View File

@ -7,11 +7,11 @@ import logging
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_text.run_embed_text import run_embed_text
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
logger = logging.getLogger(__name__)

View File

@ -8,13 +8,13 @@ import logging
from dataclasses import dataclass
import numpy as np
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.chunking.token_chunker import split_text_on_tokens
from graphrag.index.utils.is_null import is_null
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -4,12 +4,12 @@
"""A module containing build_mixed_context method definition."""
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
sort_context,
)
from graphrag.tokenizer.tokenizer import Tokenizer
def build_mixed_context(

View File

@ -7,6 +7,7 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -30,7 +31,6 @@ from graphrag.index.utils.dataframes import (
where_column_equals,
)
from graphrag.logger.progress import progress_iterable
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -3,9 +3,9 @@
"""Sort context by degree in descending order."""
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.tokenizer.tokenizer import Tokenizer
def sort_context(

View File

@ -7,6 +7,7 @@ import logging
from collections.abc import Callable
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@ -26,7 +27,6 @@ from graphrag.index.operations.summarize_communities.utils import (
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.language_model.protocol.base import ChatModel
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,6 +7,7 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.build_mixed_context import (
@ -18,7 +19,6 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
sort_context,
)
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -6,9 +6,9 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,102 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'TokenTextSplitter' class and 'split_single_text_on_tokens' function."""
import logging
from abc import ABC
from collections.abc import Callable
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.get_tokenizer import get_tokenizer
# Type aliases for the tokenizer callbacks consumed by split_single_text_on_tokens:
# an encoded text is a flat list of token ids.
EncodedText = list[int]
# Turns a list of token ids back into a string.
DecodeFn = Callable[[EncodedText], str]
# Turns a string into its list of token ids.
EncodeFn = Callable[[str], EncodedText]
# Measures the "length" of a string (defaults to len in TokenTextSplitter).
LengthFn = Callable[[str], int]
logger = logging.getLogger(__name__)
class TokenTextSplitter(ABC):
    """Split raw text into token-bounded chunks using a pluggable Tokenizer."""

    _chunk_size: int
    _chunk_overlap: int
    _length_function: LengthFn
    _keep_separator: bool
    _add_start_index: bool
    _strip_whitespace: bool

    def __init__(
        self,
        # based on OpenAI embedding chunk size limits
        # https://devblogs.microsoft.com/azure-sql/embedding-models-and-dimensions-optimizing-the-performance-resource-usage-ratio/
        chunk_size: int = 8191,
        chunk_overlap: int = 100,
        length_function: LengthFn = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
        tokenizer: Tokenizer | None = None,
    ):
        """Store chunking options; fall back to the default tokenizer when none is supplied."""
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index
        self._strip_whitespace = strip_whitespace
        # Same truthiness semantics as `tokenizer or get_tokenizer()`.
        self._tokenizer = tokenizer if tokenizer else get_tokenizer()

    def num_tokens(self, text: str) -> int:
        """Return the number of tokens in *text* per the configured tokenizer."""
        return self._tokenizer.num_tokens(text)

    def split_text(self, text: str | list[str]) -> list[str]:
        """Split *text* into token chunks; a list input is space-joined first.

        Returns an empty list for NA/empty scalar input and raises TypeError
        for any other non-string value.
        """
        if isinstance(text, list):
            text = " ".join(text)
        else:
            # NA (per pandas) or empty string produces no chunks.
            if cast("bool", pd.isna(text)) or text == "":
                return []
        if not isinstance(text, str):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)
        return split_single_text_on_tokens(
            text,
            encode=self._tokenizer.encode,
            decode=self._tokenizer.decode,
            tokens_per_chunk=self._chunk_size,
            chunk_overlap=self._chunk_overlap,
        )
def split_single_text_on_tokens(
    text: str,
    tokens_per_chunk: int,
    chunk_overlap: int,
    encode: EncodeFn,
    decode: DecodeFn,
) -> list[str]:
    """Split a single text into chunks of at most ``tokens_per_chunk`` tokens.

    Consecutive chunks share ``chunk_overlap`` tokens. An empty input yields
    an empty list.

    Args:
        text: The text to split.
        tokens_per_chunk: Maximum number of tokens per chunk; must be positive.
        chunk_overlap: Number of tokens shared between consecutive chunks;
            must satisfy ``0 <= chunk_overlap < tokens_per_chunk``.
        encode: Callback mapping a string to its list of token ids.
        decode: Callback mapping a list of token ids back to a string.

    Returns:
        The list of decoded chunk strings, in order.

    Raises:
        ValueError: If the chunk parameters would cause the window to never
            advance (infinite loop) or to silently skip tokens.
    """
    if tokens_per_chunk <= 0:
        msg = f"tokens_per_chunk must be positive, got {tokens_per_chunk}"
        raise ValueError(msg)
    # Guard the window stride: with chunk_overlap >= tokens_per_chunk the
    # original loop re-read the same window forever; a negative overlap would
    # silently drop tokens between chunks.
    if not 0 <= chunk_overlap < tokens_per_chunk:
        msg = (
            f"chunk_overlap must be in [0, tokens_per_chunk), "
            f"got chunk_overlap={chunk_overlap}, tokens_per_chunk={tokens_per_chunk}"
        )
        raise ValueError(msg)
    result = []
    input_ids = encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        chunk_text = decode(list(chunk_ids))
        result.append(chunk_text)  # Append chunked text as string
        if cur_idx == len(input_ids):
            break
        # Advance by a positive stride (guaranteed by the guard above).
        start_idx += tokens_per_chunk - chunk_overlap
        cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return result

View File

@ -8,6 +8,7 @@ import logging
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.chunking.chunker import Chunker
@ -19,7 +20,6 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)

View File

@ -6,6 +6,7 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@ -29,7 +30,6 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,

View File

@ -6,6 +6,7 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
@ -28,7 +29,6 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)

View File

@ -6,6 +6,7 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import (
@ -23,7 +24,6 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
write_table_to_storage,

View File

@ -5,6 +5,8 @@
from pathlib import Path
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.prompt_tune.template.extract_graph import (
EXAMPLE_EXTRACTION_TEMPLATE,
GRAPH_EXTRACTION_JSON_PROMPT,
@ -13,7 +15,6 @@ from graphrag.prompt_tune.template.extract_graph import (
UNTYPED_GRAPH_EXTRACTION_PROMPT,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EXTRACT_GRAPH_FILENAME = "extract_graph.txt"

View File

@ -8,11 +8,11 @@ import random
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,9 +7,9 @@ from dataclasses import dataclass
from enum import Enum
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Enum for conversation roles

View File

@ -10,12 +10,13 @@ from copy import deepcopy
from time import time
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,6 +7,7 @@ from collections import defaultdict
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
@ -24,7 +25,6 @@ from graphrag.query.input.retrieval.relationships import (
to_relationship_dataframe,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
def build_entity_context(

View File

@ -9,11 +9,11 @@ from contextlib import nullcontext
from typing import Any
import numpy as np
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,11 +7,11 @@ import random
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Contain util functions to build text unit context for the search's system prompt

View File

@ -9,10 +9,10 @@ import re
from collections.abc import Iterator
from itertools import islice
from graphrag_common.types.tokenizer import Tokenizer
from json_repair import repair_json
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,13 +7,14 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
GlobalContextBuilder,
LocalContextBuilder,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass

View File

@ -7,6 +7,8 @@ import logging
import time
from typing import Any, cast
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
@ -18,7 +20,6 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -9,6 +9,7 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
@ -21,7 +22,6 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass

View File

@ -7,6 +7,7 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
@ -16,7 +17,6 @@ from graphrag.query.context_builder.builders import (
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@ -8,6 +8,8 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.basic_search_system_prompt import (
@ -16,7 +18,6 @@ from graphrag.prompts.query.basic_search_system_prompt import (
from graphrag.query.context_builder.builders import BasicContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
"""

View File

@ -9,6 +9,7 @@ from typing import Any
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.data_model.community_report import CommunityReport
@ -28,7 +29,6 @@ from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@ -10,6 +10,7 @@ import time
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@ -20,7 +21,6 @@ from graphrag.prompts.query.drift_search_system_prompt import (
)
from graphrag.query.structured_search.base import SearchResult
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -8,6 +8,7 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from tqdm.asyncio import tqdm_asyncio
from graphrag.callbacks.query_callbacks import QueryCallbacks
@ -26,7 +27,6 @@ from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -5,6 +5,8 @@
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
@ -20,7 +22,6 @@ from graphrag.query.context_builder.dynamic_community_selection import (
)
from graphrag.query.structured_search.base import GlobalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
class GlobalCommunityContext(GlobalContextBuilder):

View File

@ -12,6 +12,7 @@ from dataclasses import dataclass
from typing import Any
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
@ -31,7 +32,6 @@ from graphrag.query.context_builder.conversation_history import (
)
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -7,6 +7,7 @@ from copy import deepcopy
from typing import Any
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.covariate import Covariate
@ -41,7 +42,6 @@ from graphrag.query.input.retrieval.community_reports import (
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
from graphrag.query.structured_search.base import LocalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@ -8,6 +8,8 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.local_search_system_prompt import (
@ -18,7 +20,6 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@ -3,11 +3,12 @@
"""Get Tokenizer."""
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.config.defaults import ENCODING_MODEL
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
def get_tokenizer(

View File

@ -3,10 +3,9 @@
"""LiteLLM Tokenizer."""
from graphrag_common.types.tokenizer import Tokenizer
from litellm import decode, encode # type: ignore
from graphrag.tokenizer.tokenizer import Tokenizer
class LitellmTokenizer(Tokenizer):
"""LiteLLM Tokenizer."""

View File

@ -4,8 +4,7 @@
"""Tiktoken Tokenizer."""
import tiktoken
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag_common.types.tokenizer import Tokenizer
class TiktokenTokenizer(Tokenizer):

View File

@ -23,6 +23,7 @@ class MockTokenizer:
tokenizer = get_tokenizer()
class TestRunSentences:
def setup_method(self, method):
bootstrap()
@ -74,6 +75,7 @@ class TestRunTokens:
assert len(chunks) > 0
def test_split_text_str_empty():
result = split_text_on_tokens(
"",