Fix dangling token_encoder refs
commit 4c4ab16070 (parent 3f93a0095d)
@@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
 - `llm`: OpenAI model object to be used for response generation
 - `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
 - `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
-- `token_encoder`: token encoder for tracking the budget for the algorithm.
+- `tokenizer`: token encoder for tracking the budget for the algorithm.
 - `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
 
 ## How to Use
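For orientation, a minimal construction sketch that matches the renamed parameter. It is not part of this diff: the import path for DRIFTSearch and the placeholder model, context builder, and config values are assumptions.

# Hypothetical sketch; only the `tokenizer` keyword is taken from the docs change above.
from graphrag.query.structured_search.drift_search.search import DRIFTSearch  # assumed path
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(encoding_model="cl100k_base")  # encoding name is an assumption

search = DRIFTSearch(
    llm=chat_model,                   # ChatModel instance created elsewhere
    context_builder=context_builder,  # DRIFT context builder built from index outputs
    config=drift_config,              # DRIFT Config model instance
    tokenizer=tokenizer,              # was `token_encoder` before this commit
)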
@@ -11,7 +11,6 @@ def get_embedding_settings(
     vector_store_params: dict | None = None,
 ) -> dict:
     """Transform GraphRAG config into settings for workflows."""
     # TEMP
     embeddings_llm_settings = settings.get_language_model_config(
         settings.embed_text.model_id
     )
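For context, a call sketch for the helper shown above. The module path for get_embedding_settings and the use of load_config are assumptions about the surrounding API, not part of this diff.

# Hypothetical usage; graphrag.config.embeddings is an assumed module path.
from pathlib import Path

from graphrag.config.embeddings import get_embedding_settings  # assumed path
from graphrag.config.load_config import load_config  # assumed helper for loading settings.yaml

settings = load_config(Path("./ragtest"))  # assumed project root
embedding_settings = get_embedding_settings(settings)
print(embedding_settings)  # plain dict consumed by the embedding workflows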
@@ -9,19 +9,20 @@ import re
 from collections.abc import Iterator
 from itertools import islice
 
-import tiktoken
 from json_repair import repair_json
 
 import graphrag.config.defaults as defs
+from graphrag.tokenizer.get_tokenizer import get_tokenizer
+from graphrag.tokenizer.tokenizer import Tokenizer
 
 logger = logging.getLogger(__name__)
 
 
-def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int:
+def num_tokens(text: str, tokenizer: Tokenizer | None = None) -> int:
     """Return the number of tokens in the given text."""
-    if token_encoder is None:
-        token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
-    return len(token_encoder.encode(text))  # type: ignore
+    if tokenizer is None:
+        tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
+    return len(tokenizer.encode(text))  # type: ignore
 
 
 def batched(iterable: Iterator, n: int):
@@ -39,15 +40,13 @@ def batched(iterable: Iterator, n: int):
         yield batch
 
 
-def chunk_text(
-    text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None
-):
+def chunk_text(text: str, max_tokens: int, tokenizer: Tokenizer | None = None):
     """Chunk text by token length."""
-    if token_encoder is None:
-        token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
-    tokens = token_encoder.encode(text)  # type: ignore
+    if tokenizer is None:
+        tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
+    tokens = tokenizer.encode(text)  # type: ignore
     chunk_iterator = batched(iter(tokens), max_tokens)
-    yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator)
+    yield from (tokenizer.decode(list(chunk)) for chunk in chunk_iterator)
 
 
 def try_parse_json_object(input: str, verbose: bool = True) -> tuple[str, dict]:
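A short usage sketch of the updated helpers. The signatures mirror the hunk above; the encoding name passed to get_tokenizer is an assumption.

from graphrag.query.llm.text_utils import chunk_text, num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# Build a Tokenizer once and reuse it; "cl100k_base" is an assumed encoding name.
tokenizer = get_tokenizer(encoding_model="cl100k_base")

text = "GraphRAG combines knowledge graphs with retrieval-augmented generation."
print(num_tokens(text, tokenizer))  # token count under the chosen encoding

# chunk_text is a generator: each yielded chunk decodes at most max_tokens tokens back to text.
for chunk in chunk_text(text, max_tokens=8, tokenizer=tokenizer):
    print(chunk)

Passing tokenizer=None keeps the previous fallback behaviour, now routed through get_tokenizer(encoding_model=defs.ENCODING_MODEL) rather than tiktoken.get_encoding.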
@@ -7,13 +7,12 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any
 
-import tiktoken
-
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.query.context_builder.builders import (
     GlobalContextBuilder,
     LocalContextBuilder,
 )
+from graphrag.tokenizer.tokenizer import Tokenizer
 
 
 @dataclass
@@ -34,13 +33,13 @@ class BaseQuestionGen(ABC):
         self,
         model: ChatModel,
         context_builder: GlobalContextBuilder | LocalContextBuilder,
-        token_encoder: tiktoken.Encoding | None = None,
+        tokenizer: Tokenizer | None = None,
         model_params: dict[str, Any] | None = None,
         context_builder_params: dict[str, Any] | None = None,
     ):
         self.model = model
         self.context_builder = context_builder
-        self.token_encoder = token_encoder
+        self.tokenizer = tokenizer
         self.model_params = model_params or {}
         self.context_builder_params = context_builder_params or {}
 
@@ -7,8 +7,6 @@ import logging
 import time
 from typing import Any, cast
 
-import tiktoken
-
 from graphrag.callbacks.llm_callbacks import BaseLLMCallback
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
@@ -21,6 +19,7 @@ from graphrag.query.context_builder.conversation_history import (
 )
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
+from graphrag.tokenizer.tokenizer import Tokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -32,7 +31,7 @@ class LocalQuestionGen(BaseQuestionGen):
         self,
         model: ChatModel,
         context_builder: LocalContextBuilder,
-        token_encoder: tiktoken.Encoding | None = None,
+        tokenizer: Tokenizer | None = None,
         system_prompt: str = QUESTION_SYSTEM_PROMPT,
         callbacks: list[BaseLLMCallback] | None = None,
         model_params: dict[str, Any] | None = None,
@@ -41,7 +40,7 @@ class LocalQuestionGen(BaseQuestionGen):
         super().__init__(
             model=model,
             context_builder=context_builder,
-            token_encoder=token_encoder,
+            tokenizer=tokenizer,
             model_params=model_params,
             context_builder_params=context_builder_params,
         )
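A migration sketch for call sites that construct LocalQuestionGen. Only the keyword rename is taken from this diff; the import path and the placeholder model and context builder are assumptions.

# Hypothetical call-site migration from token_encoder= to tokenizer=.
from graphrag.query.question_gen.local_gen import LocalQuestionGen  # assumed path
from graphrag.tokenizer.get_tokenizer import get_tokenizer

question_gen = LocalQuestionGen(
    model=chat_model,                        # ChatModel instance created elsewhere
    context_builder=local_context_builder,   # LocalContextBuilder built from index outputs
    tokenizer=get_tokenizer(encoding_model="cl100k_base"),  # previously token_encoder=tiktoken.Encoding
)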
@@ -118,7 +117,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 },
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
 
         except Exception:
@@ -128,7 +127,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 context_data=context_records,
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
 
     async def generate(
@@ -201,7 +200,7 @@ class LocalQuestionGen(BaseQuestionGen):
                 },
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )
 
         except Exception:
@@ -211,5 +210,5 @@ class LocalQuestionGen(BaseQuestionGen):
                 context_data=context_records,
                 completion_time=time.time() - start_time,
                 llm_calls=1,
-                prompt_tokens=num_tokens(system_prompt, self.token_encoder),
+                prompt_tokens=num_tokens(system_prompt, self.tokenizer),
             )