Move Tokenizer back to GR core

This commit is contained in:
Nathan Evans 2025-12-22 11:18:42 -08:00
parent 247547f5bc
commit 90479c0b1c
41 changed files with 58 additions and 71 deletions

View File

@@ -1,8 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""The GraphRAG types module: contains commonly-reused types across the monorepo packages."""
from graphrag_common.types.tokenizer import Tokenizer
__all__ = ["Tokenizer"]

View File

@@ -6,7 +6,6 @@
from collections.abc import Callable
from graphrag_common.factory.factory import Factory, ServiceScope
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.chunking.chunk_strategy_type import ChunkStrategyType
from graphrag.chunking.chunker import Chunker
@@ -38,7 +37,9 @@ def register_chunker(
def create_chunker(
config: ChunkingConfig, tokenizer: Tokenizer | None = None
config: ChunkingConfig,
encode: Callable[[str], list[int]] | None,
decode: Callable[[list[int]], str] | None,
) -> Chunker:
"""Create a chunker implementation based on the given configuration.
@@ -53,8 +54,10 @@ def create_chunker(
The created chunker implementation.
"""
config_model = config.model_dump()
if tokenizer is not None:
config_model["tokenizer"] = tokenizer
if encode is not None:
config_model["encode"] = encode
if decode is not None:
config_model["decode"] = decode
chunker_strategy = config.strategy
if chunker_strategy not in chunker_factory:

View File

@@ -6,14 +6,8 @@
from collections.abc import Callable
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.chunking.chunker import Chunker
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
EncodeFn = Callable[[str], EncodedText]
class TokenChunker(Chunker):
"""A chunker that splits text into token-based chunks."""
@@ -22,13 +16,15 @@ class TokenChunker(Chunker):
self,
size: int,
overlap: int,
tokenizer: Tokenizer,
encode: Callable[[str], list[int]],
decode: Callable[[list[int]], str],
**kwargs: Any,
) -> None:
"""Create a token chunker instance."""
self._size = size
self._overlap = overlap
self._tokenizer = tokenizer
self._encode = encode
self._decode = decode
def chunk(self, text: str) -> list[str]:
"""Chunk the text into token-based chunks."""
@@ -36,8 +32,8 @@ class TokenChunker(Chunker):
text,
chunk_size=self._size,
chunk_overlap=self._overlap,
encode=self._tokenizer.encode,
decode=self._tokenizer.decode,
encode=self._encode,
decode=self._decode,
)
@@ -45,8 +41,8 @@ def split_text_on_tokens(
text: str,
chunk_size: int,
chunk_overlap: int,
encode: EncodeFn,
decode: DecodeFn,
encode: Callable[[str], list[int]],
decode: Callable[[list[int]], str],
) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []

View File

@@ -7,11 +7,11 @@ import logging
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_text.run_embed_text import run_embed_text
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
logger = logging.getLogger(__name__)

View File

@@ -8,13 +8,13 @@ import logging
from dataclasses import dataclass
import numpy as np
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.chunking.token_chunker import split_text_on_tokens
from graphrag.index.utils.is_null import is_null
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -4,12 +4,12 @@
"""A module containing build_mixed_context method definition."""
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
sort_context,
)
from graphrag.tokenizer.tokenizer import Tokenizer
def build_mixed_context(

View File

@@ -7,7 +7,6 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -31,6 +30,7 @@ from graphrag.index.utils.dataframes import (
where_column_equals,
)
from graphrag.logger.progress import progress_iterable
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -3,9 +3,9 @@
"""Sort context by degree in descending order."""
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.tokenizer.tokenizer import Tokenizer
def sort_context(

View File

@@ -7,7 +7,6 @@ import logging
from collections.abc import Callable
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
@@ -27,6 +26,7 @@ from graphrag.index.operations.summarize_communities.utils import (
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.language_model.protocol.base import ChatModel
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,6 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.build_mixed_context import (
@@ -19,6 +18,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
sort_context,
)
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -6,9 +6,9 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -9,9 +9,9 @@ from collections.abc import Callable
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]

View File

@@ -8,7 +8,6 @@ import logging
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.chunking.chunker import Chunker
@@ -20,6 +19,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ async def run_workflow(
documents = await load_table_from_storage("documents", context.output_storage)
tokenizer = get_tokenizer(encoding_model=config.chunks.encoding_model)
chunker = create_chunker(config.chunks, tokenizer)
chunker = create_chunker(config.chunks, tokenizer.encode, tokenizer.decode)
output = create_base_text_units(
documents,
context.callbacks,
@@ -71,7 +71,9 @@ def create_base_text_units(
metadata = row.get("metadata", None)
if prepend_metadata and metadata is not None:
metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
row["chunks"] = [prepend_metadata_fn(chunk, metadata) for chunk in row["chunks"]]
row["chunks"] = [
prepend_metadata_fn(chunk, metadata) for chunk in row["chunks"]
]
tick()
logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
return row

View File

@@ -6,7 +6,6 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
import graphrag.data_model.schemas as schemas
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -30,6 +29,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,

View File

@@ -6,7 +6,6 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
@@ -29,6 +28,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,6 @@
import logging
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import (
@@ -24,6 +23,7 @@ from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
write_table_to_storage,

View File

@@ -5,8 +5,6 @@
from pathlib import Path
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.prompt_tune.template.extract_graph import (
EXAMPLE_EXTRACTION_TEMPLATE,
GRAPH_EXTRACTION_JSON_PROMPT,
@@ -15,6 +13,7 @@ from graphrag.prompt_tune.template.extract_graph import (
UNTYPED_GRAPH_EXTRACTION_PROMPT,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EXTRACT_GRAPH_FILENAME = "extract_graph.txt"

View File

@@ -8,11 +8,11 @@ import random
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,9 +7,9 @@ from dataclasses import dataclass
from enum import Enum
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Enum for conversation roles

View File

@@ -10,13 +10,12 @@ from copy import deepcopy
from time import time
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,6 @@ from collections import defaultdict
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
@@ -25,6 +24,7 @@ from graphrag.query.input.retrieval.relationships import (
to_relationship_dataframe,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
def build_entity_context(

View File

@@ -9,11 +9,11 @@ from contextlib import nullcontext
from typing import Any
import numpy as np
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,11 +7,11 @@ import random
from typing import Any, cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
"""
Contain util functions to build text unit context for the search's system prompt

View File

@@ -9,10 +9,10 @@ import re
from collections.abc import Iterator
from itertools import islice
from graphrag_common.types.tokenizer import Tokenizer
from json_repair import repair_json
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,14 +7,13 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
GlobalContextBuilder,
LocalContextBuilder,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass

View File

@@ -7,8 +7,6 @@ import logging
import time
from typing import Any, cast
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
@@ -20,6 +18,7 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -9,7 +9,6 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
@@ -22,6 +21,7 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass

View File

@@ -7,7 +7,6 @@ import logging
from typing import cast
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
@@ -17,6 +16,7 @@ from graphrag.query.context_builder.builders import (
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@@ -8,8 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.basic_search_system_prompt import (
@@ -18,6 +16,7 @@ from graphrag.prompts.query.basic_search_system_prompt import (
from graphrag.query.context_builder.builders import BasicContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
"""

View File

@@ -9,7 +9,6 @@ from typing import Any
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.data_model.community_report import CommunityReport
@@ -29,6 +28,7 @@ from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,6 @@ import time
import numpy as np
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
@@ -21,6 +20,7 @@ from graphrag.prompts.query.drift_search_system_prompt import (
)
from graphrag.query.structured_search.base import SearchResult
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -8,7 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from tqdm.asyncio import tqdm_asyncio
from graphrag.callbacks.query_callbacks import QueryCallbacks
@@ -27,6 +26,7 @@ from graphrag.query.structured_search.drift_search.primer import DRIFTPrimer
from graphrag.query.structured_search.drift_search.state import QueryState
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -5,8 +5,6 @@
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.entity import Entity
@@ -22,6 +20,7 @@ from graphrag.query.context_builder.dynamic_community_selection import (
)
from graphrag.query.structured_search.base import GlobalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
class GlobalCommunityContext(GlobalContextBuilder):

View File

@@ -12,7 +12,6 @@ from dataclasses import dataclass
from typing import Any
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
@@ -32,6 +31,7 @@ from graphrag.query.context_builder.conversation_history import (
)
from graphrag.query.llm.text_utils import try_parse_json_object
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,6 @@ from copy import deepcopy
from typing import Any
import pandas as pd
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.covariate import Covariate
@@ -42,6 +41,7 @@ from graphrag.query.input.retrieval.community_reports import (
from graphrag.query.input.retrieval.text_units import get_candidate_text_units
from graphrag.query.structured_search.base import LocalContextBuilder
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore
logger = logging.getLogger(__name__)

View File

@@ -8,8 +8,6 @@ import time
from collections.abc import AsyncGenerator
from typing import Any
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.callbacks.query_callbacks import QueryCallbacks
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.local_search_system_prompt import (
@@ -20,6 +18,7 @@ from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)

View File

@@ -3,12 +3,11 @@
"""Get Tokenizer."""
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.config.defaults import ENCODING_MODEL
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
def get_tokenizer(

View File

@@ -3,9 +3,10 @@
"""LiteLLM Tokenizer."""
from graphrag_common.types.tokenizer import Tokenizer
from litellm import decode, encode # type: ignore
from graphrag.tokenizer.tokenizer import Tokenizer
class LitellmTokenizer(Tokenizer):
"""LiteLLM Tokenizer."""

View File

@@ -4,7 +4,8 @@
"""Tiktoken Tokenizer."""
import tiktoken
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
class TiktokenTokenizer(Tokenizer):

View File

@@ -11,7 +11,7 @@ from graphrag.chunking.token_chunker import (
split_text_on_tokens,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag_common.types.tokenizer import Tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
class MockTokenizer(Tokenizer):