Fix dangling token_encoder refs

Nathan Evans 2025-10-06 15:30:53 -07:00
parent 3f93a0095d
commit 4c4ab16070
5 changed files with 22 additions and 26 deletions

View File

@@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
 - `llm`: OpenAI model object to be used for response generation
 - `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
 - `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
-- `token_encoder`: token encoder for tracking the budget for the algorithm.
+- `tokenizer`: token encoder for tracking the budget for the algorithm.
 - `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
 ## How to Use
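
For orientation, a minimal wiring sketch of the renamed parameter is shown below. The keyword names follow the parameter list above; the `DRIFTSearch` import path, the `DRIFTSearchConfig` class name, and the pre-built `llm` and `context_builder` objects are assumptions for illustration, not verbatim library usage.

```python
# Hypothetical sketch: constructing DRIFTSearch with `tokenizer` instead of the
# removed `token_encoder`. Import paths and the llm/context_builder objects are assumed.
import graphrag.config.defaults as defs
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.query.structured_search.drift_search.search import DRIFTSearch  # path assumed
from graphrag.tokenizer.get_tokenizer import get_tokenizer

search = DRIFTSearch(
    llm=llm,                          # OpenAI model object, built elsewhere
    context_builder=context_builder,  # DRIFT context builder, built elsewhere
    config=DRIFTSearchConfig(),       # DRIFT Search hyperparameters
    tokenizer=get_tokenizer(encoding_model=defs.ENCODING_MODEL),
)
```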

View File

@@ -11,7 +11,6 @@ def get_embedding_settings(
     vector_store_params: dict | None = None,
 ) -> dict:
     """Transform GraphRAG config into settings for workflows."""
-    # TEMP
     embeddings_llm_settings = settings.get_language_model_config(
         settings.embed_text.model_id
     )

View File

@@ -9,19 +9,20 @@ import re
 from collections.abc import Iterator
 from itertools import islice
-import tiktoken
 from json_repair import repair_json
 import graphrag.config.defaults as defs
+from graphrag.tokenizer.get_tokenizer import get_tokenizer
+from graphrag.tokenizer.tokenizer import Tokenizer
 logger = logging.getLogger(__name__)
-def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int:
+def num_tokens(text: str, tokenizer: Tokenizer | None = None) -> int:
     """Return the number of tokens in the given text."""
-    if token_encoder is None:
-        token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
-    return len(token_encoder.encode(text))  # type: ignore
+    if tokenizer is None:
+        tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
+    return len(tokenizer.encode(text))  # type: ignore
 def batched(iterable: Iterator, n: int):
@@ -39,15 +40,13 @@ def batched(iterable: Iterator, n: int):
         yield batch
-def chunk_text(
-    text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None
-):
+def chunk_text(text: str, max_tokens: int, tokenizer: Tokenizer | None = None):
     """Chunk text by token length."""
-    if token_encoder is None:
-        token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
-    tokens = token_encoder.encode(text)  # type: ignore
+    if tokenizer is None:
+        tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
+    tokens = tokenizer.encode(text)  # type: ignore
     chunk_iterator = batched(iter(tokens), max_tokens)
-    yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator)
+    yield from (tokenizer.decode(list(chunk)) for chunk in chunk_iterator)
 def try_parse_json_object(input: str, verbose: bool = True) -> tuple[str, dict]:
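
As a quick check of the updated signatures, a short usage sketch follows. The function signatures and the `get_tokenizer(encoding_model=defs.ENCODING_MODEL)` fallback come from the diff above; the sample text and `max_tokens` value are illustrative.

```python
# Usage sketch of the updated helpers; omitting `tokenizer` falls back to the
# same default encoding the functions construct internally.
import graphrag.config.defaults as defs
from graphrag.query.llm.text_utils import chunk_text, num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer

text = "GraphRAG builds a knowledge graph over your documents."
tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)

# Count tokens with an explicit tokenizer (equivalent to num_tokens(text)).
print(num_tokens(text, tokenizer))

# chunk_text yields decoded chunks of at most `max_tokens` tokens each.
for chunk in chunk_text(text, max_tokens=8, tokenizer=tokenizer):
    print(chunk)
```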

View File

@@ -7,13 +7,12 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any
-import tiktoken
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.query.context_builder.builders import (
     GlobalContextBuilder,
     LocalContextBuilder,
 )
+from graphrag.tokenizer.tokenizer import Tokenizer
 @dataclass
@@ -34,13 +33,13 @@ class BaseQuestionGen(ABC):
         self,
         model: ChatModel,
         context_builder: GlobalContextBuilder | LocalContextBuilder,
-        token_encoder: tiktoken.Encoding | None = None,
+        tokenizer: Tokenizer | None = None,
         model_params: dict[str, Any] | None = None,
         context_builder_params: dict[str, Any] | None = None,
     ):
         self.model = model
         self.context_builder = context_builder
-        self.token_encoder = token_encoder
+        self.tokenizer = tokenizer
         self.model_params = model_params or {}
         self.context_builder_params = context_builder_params or {}

View File

@@ -7,8 +7,6 @@ import logging
 import time
 from typing import Any, cast
-import tiktoken
 from graphrag.callbacks.llm_callbacks import BaseLLMCallback
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
@@ -21,6 +19,7 @@ from graphrag.query.context_builder.conversation_history import (
 )
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
+from graphrag.tokenizer.tokenizer import Tokenizer
 logger = logging.getLogger(__name__)
@@ -32,7 +31,7 @@ class LocalQuestionGen(BaseQuestionGen):
         self,
         model: ChatModel,
         context_builder: LocalContextBuilder,
-        token_encoder: tiktoken.Encoding | None = None,
+        tokenizer: Tokenizer | None = None,
         system_prompt: str = QUESTION_SYSTEM_PROMPT,
         callbacks: list[BaseLLMCallback] | None = None,
         model_params: dict[str, Any] | None = None,
@@ -41,7 +40,7 @@ class LocalQuestionGen(BaseQuestionGen):
         super().__init__(
             model=model,
             context_builder=context_builder,
-            token_encoder=token_encoder,
+            tokenizer=tokenizer,
             model_params=model_params,
             context_builder_params=context_builder_params,
         )
@@ -118,7 +117,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 },
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
         except Exception:
@@ -128,7 +127,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 context_data=context_records,
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
     async def generate(
@@ -201,7 +200,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 },
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
         except Exception:
@@ -211,5 +210,5 @@ class LocalQuestionGen(BaseQuestionGen):
                 context_data=context_records,
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
            )
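
To make the call-site change concrete, here is a hedged before/after sketch for instantiating `LocalQuestionGen`. The `LocalQuestionGen` import path is assumed, and `chat_model` and `context_builder` are assumed to be constructed elsewhere; only the keyword rename itself is taken from the diff.

```python
# Hypothetical migration sketch (import path assumed; chat_model and
# context_builder are assumed to be constructed elsewhere).
import graphrag.config.defaults as defs
from graphrag.query.question_gen.local_gen import LocalQuestionGen  # path assumed
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# Before (the removed keyword):
#   question_gen = LocalQuestionGen(model=chat_model, context_builder=context_builder,
#                                   token_encoder=tiktoken.get_encoding(defs.ENCODING_MODEL))

# After: pass a graphrag Tokenizer; prompt_tokens accounting in QuestionResult
# now goes through num_tokens(system_prompt, self.tokenizer) internally.
question_gen = LocalQuestionGen(
    model=chat_model,
    context_builder=context_builder,
    tokenizer=get_tokenizer(encoding_model=defs.ENCODING_MODEL),
)
```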