Housekeeping (#2086)

* Add deprecation warnings for fnllm and multi-search

* Fix dangling token_encoder refs

* Fix local_search notebook

* Fix global search dynamic notebook

* Fix global search notebook

* Fix drift notebook

* Switch example notebooks to use LiteLLM config

* Properly annotate dev deps as a group

* Semver

* Remove --extra dev

* Remove llm_model variable

* Ignore ruff ASYNC240

* Add note about expected broken notebook in docs

* Fix custom vector store notebook

* Push tokenizer throughout
Nathan Evans 2025-10-07 16:21:24 -07:00 committed by GitHub
parent 6c86b0a7bb
commit ac8a7f5eef
42 changed files with 549 additions and 514 deletions

View File

@ -31,7 +31,7 @@ jobs:
- name: Install dependencies
shell: bash
run: uv sync --extra dev
run: uv sync
- name: mkdocs build
shell: bash

View File

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Check

View File

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Build

View File

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Notebook Test

View File

@ -72,7 +72,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Build

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Housekeeping toward 2.7."
}

View File

@ -11,12 +11,8 @@
## Install Dependencies
```shell
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate
# install python dependencies
uv sync --extra dev
uv sync
```
## Execute the indexing engine

View File

@ -12,12 +12,8 @@
## Install Dependencies
```sh
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate
# install python dependencies
uv sync --extra dev
uv sync
```
## Execute the Indexing Engine

View File

@ -67,6 +67,8 @@
"metadata": {},
"outputs": [],
"source": [
"# note that we expect this to fail on the deployed docs because the PROJECT_DIRECTORY is not set to a real location.\n",
"# if you run this notebook locally, make sure to point at a location containing your settings.yaml\n",
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
]
},
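As a minimal local sketch of the cell above (the import path for `load_config` and the `./ragtest` directory are assumptions, not shown in this diff):

```python
from pathlib import Path

from graphrag.config.load_config import load_config  # assumed import path

# hypothetical project directory containing your settings.yaml
PROJECT_DIRECTORY = "./ragtest"

# reads settings.yaml (and any environment overrides) from the project directory
graphrag_config = load_config(Path(PROJECT_DIRECTORY))
print(type(graphrag_config).__name__)
```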

View File

@ -61,6 +61,7 @@
"import numpy as np\n",
"import yaml\n",
"\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.data_model.types import TextEmbedder\n",
"\n",
"# GraphRAG vector store components\n",
@ -147,14 +148,12 @@
" self.vectors: dict[str, np.ndarray] = {}\n",
" self.connected = False\n",
"\n",
" print(\n",
" f\"🚀 SimpleInMemoryVectorStore initialized for collection: {self.collection_name}\"\n",
" )\n",
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
"\n",
" def connect(self, **kwargs: Any) -> None:\n",
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
" self.connected = True\n",
" print(f\"✅ Connected to in-memory vector store: {self.collection_name}\")\n",
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
"\n",
" def load_documents(\n",
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
@ -250,7 +249,7 @@
" def get_stats(self) -> dict[str, Any]:\n",
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
" return {\n",
" \"collection_name\": self.collection_name,\n",
" \"index_name\": self.index_name,\n",
" \"document_count\": len(self.documents),\n",
" \"vector_count\": len(self.vectors),\n",
" \"connected\": self.connected,\n",
@ -353,11 +352,11 @@
"outputs": [],
"source": [
"# Test creating vector store using the factory\n",
"vector_store_config = {\"collection_name\": \"test_collection\"}\n",
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
"\n",
"# Create vector store instance using factory\n",
"vector_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, vector_store_config\n",
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
")\n",
"\n",
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
@ -486,9 +485,13 @@
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
"\n",
" # 1. GraphRAG creates vector store using factory\n",
" config = {\"collection_name\": \"graphrag_entities\", \"similarity_threshold\": 0.3}\n",
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
"\n",
" store = VectorStoreFactory.create_vector_store(CUSTOM_VECTOR_STORE_TYPE, config)\n",
" store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=schema,\n",
" similarity_threshold=0.3,\n",
" )\n",
" store.connect()\n",
"\n",
" print(\"✅ Step 1: Vector store created and connected\")\n",
@ -549,7 +552,8 @@
" # Test 1: Basic functionality\n",
" print(\"Test 1: Basic functionality\")\n",
" store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test\"}\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
" )\n",
" store.connect()\n",
"\n",
@ -597,7 +601,8 @@
" # Test 5: Error handling\n",
" print(\"\\nTest 5: Error handling\")\n",
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test2\"}\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
" )\n",
"\n",
" try:\n",
@ -653,7 +658,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "graphrag-venv (3.10.18)",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -667,7 +672,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -20,11 +20,11 @@
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.language_model.manager import ModelManager\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_entities,\n",
@ -37,6 +37,7 @@
" DRIFTSearchContextBuilder,\n",
")\n",
"from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
"\n",
"INPUT_DIR = \"./inputs/operation dulce\"\n",
@ -62,12 +63,16 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-entity-description\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-entity-description\"\n",
" ),\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
"full_content_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-community-full_content\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-community-full_content\"\n",
" )\n",
")\n",
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
@ -94,33 +99,33 @@
"outputs": [],
"source": [
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
"\n",
"chat_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"chat_model = ModelManager().get_or_create_chat_model(\n",
" name=\"local_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=chat_config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
"tokenizer = get_tokenizer(chat_config)\n",
"\n",
"embedding_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIEmbedding,\n",
" model=embedding_model,\n",
" type=ModelType.Embedding,\n",
" model_provider=\"openai\",\n",
" model=\"text-embedding-3-small\",\n",
" max_retries=20,\n",
")\n",
"\n",
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
" name=\"local_search_embedding\",\n",
" model_type=ModelType.OpenAIEmbedding,\n",
" model_type=ModelType.Embedding,\n",
" config=embedding_config,\n",
")"
]
@ -173,12 +178,12 @@
" reports=reports,\n",
" entity_text_embeddings=description_embedding_store,\n",
" text_units=text_units,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" config=drift_params,\n",
")\n",
"\n",
"search = DRIFTSearch(\n",
" model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
" model=chat_model, context_builder=context_builder, tokenizer=tokenizer\n",
")"
]
},
@ -212,7 +217,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -226,7 +231,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -19,7 +19,6 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
@ -32,7 +31,8 @@
"from graphrag.query.structured_search.global_search.community_context import (\n",
" GlobalCommunityContext,\n",
")\n",
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
"from graphrag.query.structured_search.global_search.search import GlobalSearch\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer"
]
},
{
@ -58,21 +58,21 @@
"outputs": [],
"source": [
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"\n",
"config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"model = ModelManager().get_or_create_chat_model(\n",
" name=\"global_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)"
"tokenizer = get_tokenizer(config)"
]
},
{
@ -142,7 +142,7 @@
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
")"
]
},
@ -193,7 +193,7 @@
"search_engine = GlobalSearch(\n",
" model=model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" map_llm_params=map_llm_params,\n",
" reduce_llm_params=reduce_llm_params,\n",
@ -241,7 +241,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -255,7 +255,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -19,7 +19,6 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
@ -57,22 +56,24 @@
"metadata": {},
"outputs": [],
"source": [
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"\n",
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"\n",
"config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"model = ModelManager().get_or_create_chat_model(\n",
" name=\"global_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)"
"tokenizer = get_tokenizer(config)"
]
},
{
@ -155,11 +156,11 @@
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" dynamic_community_selection=True,\n",
" dynamic_community_selection_kwargs={\n",
" \"model\": model,\n",
" \"token_encoder\": token_encoder,\n",
" \"tokenizer\": tokenizer,\n",
" },\n",
")"
]
@ -211,7 +212,7 @@
"search_engine = GlobalSearch(\n",
" model=model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" map_llm_params=map_llm_params,\n",
" reduce_llm_params=reduce_llm_params,\n",
@ -255,7 +256,7 @@
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
"print(\n",
" f\"Build context ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
" f\"Build context LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")\n",
"# inspect number of LLM calls and tokens in map-reduce\n",
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
@ -266,14 +267,14 @@
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
")\n",
"print(\n",
" f\"Map-reduce ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
" f\"Map-reduce LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -287,7 +288,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -19,8 +19,8 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_covariates,\n",
@ -102,7 +102,9 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-entity-description\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-entity-description\"\n",
" )\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
@ -195,37 +197,38 @@
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
"from graphrag.language_model.manager import ModelManager\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"\n",
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
"\n",
"chat_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"chat_model = ModelManager().get_or_create_chat_model(\n",
" name=\"local_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=chat_config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
"\n",
"embedding_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIEmbedding,\n",
" model=embedding_model,\n",
" type=ModelType.Embedding,\n",
" model_provider=\"openai\",\n",
" model=\"text-embedding-3-small\",\n",
" max_retries=20,\n",
")\n",
"\n",
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
" name=\"local_search_embedding\",\n",
" model_type=ModelType.OpenAIEmbedding,\n",
" model_type=ModelType.Embedding,\n",
" config=embedding_config,\n",
")"
")\n",
"\n",
"tokenizer = get_tokenizer(chat_config)"
]
},
{
@ -251,7 +254,7 @@
" entity_text_embeddings=description_embedding_store,\n",
" embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE\n",
" text_embedder=text_embedder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
")"
]
},
@ -314,7 +317,7 @@
"search_engine = LocalSearch(\n",
" model=chat_model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" model_params=model_params,\n",
" context_builder_params=local_context_params,\n",
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
@ -426,7 +429,7 @@
"question_generator = LocalQuestionGen(\n",
" model=chat_model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" model_params=model_params,\n",
" context_builder_params=local_context_params,\n",
")"
@ -451,7 +454,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -465,7 +468,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
- `llm`: OpenAI model object to be used for response generation
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
- `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
- `token_encoder`: token encoder for tracking the budget for the algorithm.
- `tokenizer`: tokenizer used to track the token budget for the algorithm (see the sketch below).
- `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows tracking the execution of a DRIFT Search instance, along with follow-ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
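Following the parameter list above, a minimal construction sketch with the renamed argument; it assumes `chat_model`, `context_builder`, and `chat_config` were built as in the DRIFT search notebook changed earlier in this PR:

```python
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(chat_config)  # chat_config: LanguageModelConfig from the notebook

search = DRIFTSearch(
    model=chat_model,                 # chat model created via ModelManager
    context_builder=context_builder,  # DRIFTSearchContextBuilder
    tokenizer=tokenizer,              # replaces the former token_encoder argument
)
```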
## How to Use

View File

@ -231,6 +231,10 @@ async def multi_index_global_search(
"""
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_global_search"
@ -510,6 +514,9 @@ async def multi_index_local_search(
"""
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_index_local_search"
@ -874,6 +881,10 @@ async def multi_index_drift_search(
"""
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_drift_search"
@ -1166,6 +1177,10 @@ async def multi_index_basic_search(
"""
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_basic_search"

View File

@ -11,7 +11,6 @@ def get_embedding_settings(
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)

View File

@ -98,6 +98,14 @@ class LanguageModelConfig(BaseModel):
if not ModelFactory.is_supported_model(self.type):
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
raise KeyError(msg)
if self.type in [
"openai_chat",
"openai_embedding",
"azure_openai_chat",
"azure_openai_embedding",
]:
msg = f"Model config based on fnllm is deprecated and will be removed in GraphRAG v3, please use {ModelType.Chat} or {ModelType.Embedding} instead to switch to LiteLLM config."
logger.warning(msg)
model_provider: str | None = Field(
description="The model provider to use.",

View File

@ -210,7 +210,7 @@ def _create_vector_store(
vector_store = VectorStoreFactory().create_vector_store(
vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
kwargs=vector_store_config,
**vector_store_config,
)
vector_store.connect(**vector_store_config)

View File

@ -8,10 +8,12 @@ import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
def build_mixed_context(
context: list[dict], tokenizer: Tokenizer, max_context_tokens: int
) -> str:
"""
Build parent context by concatenating all sub-communities' contexts.
@ -45,9 +47,10 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
new_context_string = sort_context(
local_context=remaining_local_context + final_local_contexts,
tokenizer=tokenizer,
sub_community_reports=substitute_reports,
)
if num_tokens(new_context_string) <= max_context_tokens:
if tokenizer.num_tokens(new_context_string) <= max_context_tokens:
exceeded_limit = False
context_string = new_context_string
break
@ -63,7 +66,7 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
new_context_string = pd.DataFrame(substitute_reports).to_csv(
index=False, sep=","
)
if num_tokens(new_context_string) > max_context_tokens:
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
break
context_string = new_context_string

View File

@ -30,7 +30,7 @@ from graphrag.index.utils.dataframes import (
where_column_equals,
)
from graphrag.logger.progress import progress_iterable
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -39,6 +39,7 @@ def build_local_context(
nodes,
edges,
claims,
tokenizer: Tokenizer,
callbacks: WorkflowCallbacks,
max_context_tokens: int = 16_000,
):
@ -49,7 +50,7 @@ def build_local_context(
for level in progress_iterable(levels, callbacks.progress, len(levels)):
communities_at_level_df = _prepare_reports_at_level(
nodes, edges, claims, level, max_context_tokens
nodes, edges, claims, tokenizer, level, max_context_tokens
)
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
@ -63,6 +64,7 @@ def _prepare_reports_at_level(
node_df: pd.DataFrame,
edge_df: pd.DataFrame,
claim_df: pd.DataFrame | None,
tokenizer: Tokenizer,
level: int,
max_context_tokens: int = 16_000,
) -> pd.DataFrame:
@ -181,6 +183,7 @@ def _prepare_reports_at_level(
# Generate community-level context strings using vectorized batch processing
return parallel_sort_context_batch(
community_df,
tokenizer=tokenizer,
max_context_tokens=max_context_tokens,
)
@ -189,6 +192,7 @@ def build_level_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
tokenizer: Tokenizer,
level: int,
max_context_tokens: int,
) -> pd.DataFrame:
@ -219,11 +223,11 @@ def build_level_context(
if report_df is None or report_df.empty:
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_context_tokens
invalid_context_df, tokenizer, max_context_tokens
)
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
:, schemas.CONTEXT_STRING
].map(num_tokens)
].map(tokenizer.num_tokens)
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = False
return union(valid_context_df, invalid_context_df)
@ -237,6 +241,7 @@ def build_level_context(
invalid_context_df,
sub_context_df,
community_hierarchy_df,
tokenizer,
max_context_tokens,
)
@ -244,11 +249,13 @@ def build_level_context(
# this should be rare, but if it happens, we will just trim the local context to fit the limit
remaining_df = _antijoin_reports(invalid_context_df, community_df)
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df, max_context_tokens
remaining_df, tokenizer, max_context_tokens
)
result = union(valid_context_df, community_df, remaining_df)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(
tokenizer.num_tokens
)
result[schemas.CONTEXT_EXCEED_FLAG] = False
return result
@ -269,19 +276,29 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
return antijoin(df, reports, schemas.COMMUNITY_ID)
def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
def _sort_and_trim_context(
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(
series, lambda x: sort_context(x, max_context_tokens=max_context_tokens)
series,
lambda x: sort_context(
x, tokenizer=tokenizer, max_context_tokens=max_context_tokens
),
)
def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
def _build_mixed_context(
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(
series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens)
series,
lambda x: build_mixed_context(
x, tokenizer, max_context_tokens=max_context_tokens
),
)
@ -303,6 +320,7 @@ def _get_community_df(
invalid_context_df: pd.DataFrame,
sub_context_df: pd.DataFrame,
community_hierarchy_df: pd.DataFrame,
tokenizer: Tokenizer,
max_context_tokens: int,
) -> pd.DataFrame:
"""Get community context for each community."""
@ -338,7 +356,7 @@ def _get_community_df(
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
community_df, max_context_tokens
community_df, tokenizer, max_context_tokens
)
community_df[schemas.COMMUNITY_LEVEL] = level
return community_df

View File

@ -5,11 +5,12 @@
import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
def sort_context(
local_context: list[dict],
tokenizer: Tokenizer,
sub_community_reports: list[dict] | None = None,
max_context_tokens: int | None = None,
node_name_column: str = schemas.TITLE,
@ -112,7 +113,10 @@ def sort_context(
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if max_context_tokens and num_tokens(new_context_string) > max_context_tokens:
if (
max_context_tokens
and tokenizer.num_tokens(new_context_string) > max_context_tokens
):
break
context_string = new_context_string
@ -122,7 +126,9 @@ def sort_context(
)
def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False):
def parallel_sort_context_batch(
community_df, tokenizer: Tokenizer, max_context_tokens, parallel=False
):
"""Calculate context using parallelization if enabled."""
if parallel:
# Use ThreadPoolExecutor for parallel execution
@ -131,7 +137,9 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
with ThreadPoolExecutor(max_workers=None) as executor:
context_strings = list(
executor.map(
lambda x: sort_context(x, max_context_tokens=max_context_tokens),
lambda x: sort_context(
x, tokenizer, max_context_tokens=max_context_tokens
),
community_df[schemas.ALL_CONTEXT],
)
)
@ -141,13 +149,13 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
# Assign context strings directly to the DataFrame
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda context_list: sort_context(
context_list, max_context_tokens=max_context_tokens
context_list, tokenizer, max_context_tokens=max_context_tokens
)
)
# Calculate other columns
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
num_tokens
tokenizer.num_tokens
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
community_df[schemas.CONTEXT_SIZE] > max_context_tokens

View File

@ -23,6 +23,7 @@ from graphrag.index.operations.summarize_communities.utils import (
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -35,6 +36,7 @@ async def summarize_communities(
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict,
tokenizer: Tokenizer,
max_input_length: int,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
@ -44,7 +46,6 @@ async def summarize_communities(
tick = progress_ticker(callbacks.progress, len(local_contexts))
strategy_exec = load_strategy(strategy["type"])
strategy_config = {**strategy}
community_hierarchy = (
communities.explode("children")
.rename({"children": "sub_community"}, axis=1)
@ -60,6 +61,7 @@ async def summarize_communities(
community_hierarchy_df=community_hierarchy,
local_context_df=local_contexts,
level=level,
tokenizer=tokenizer,
max_context_tokens=max_input_length,
)
level_contexts.append(level_context)

View File

@ -18,7 +18,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -27,6 +27,7 @@ def build_local_context(
community_membership_df: pd.DataFrame,
text_units_df: pd.DataFrame,
node_df: pd.DataFrame,
tokenizer: Tokenizer,
max_context_tokens: int = 16000,
) -> pd.DataFrame:
"""
@ -69,10 +70,10 @@ def build_local_context(
.reset_index()
)
context_df[schemas.CONTEXT_STRING] = context_df[schemas.ALL_CONTEXT].apply(
lambda x: sort_context(x)
lambda x: sort_context(x, tokenizer)
)
context_df[schemas.CONTEXT_SIZE] = context_df[schemas.CONTEXT_STRING].apply(
lambda x: num_tokens(x)
lambda x: tokenizer.num_tokens(x)
)
context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply(
lambda x: x > max_context_tokens
@ -86,6 +87,7 @@ def build_level_context(
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int,
tokenizer: Tokenizer,
max_context_tokens: int = 16000,
) -> pd.DataFrame:
"""
@ -116,10 +118,12 @@ def build_level_context(
invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[
schemas.ALL_CONTEXT
].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
].apply(
lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens)
)
invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[
schemas.CONTEXT_STRING
].apply(lambda x: num_tokens(x))
].apply(lambda x: tokenizer.num_tokens(x))
invalid_context_df.loc[:, [schemas.CONTEXT_EXCEED_FLAG]] = False
return pd.concat([valid_context_df, invalid_context_df])
@ -199,10 +203,10 @@ def build_level_context(
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda x: build_mixed_context(x, max_context_tokens)
lambda x: build_mixed_context(x, tokenizer, max_context_tokens)
)
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
lambda x: num_tokens(x)
lambda x: tokenizer.num_tokens(x)
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = False
community_df[schemas.COMMUNITY_LEVEL] = level
@ -220,10 +224,10 @@ def build_level_context(
)
remaining_df[schemas.CONTEXT_STRING] = cast(
"pd.DataFrame", remaining_df[schemas.ALL_CONTEXT]
).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
).apply(lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens))
remaining_df[schemas.CONTEXT_SIZE] = cast(
"pd.DataFrame", remaining_df[schemas.CONTEXT_STRING]
).apply(lambda x: num_tokens(x))
).apply(lambda x: tokenizer.num_tokens(x))
remaining_df[schemas.CONTEXT_EXCEED_FLAG] = False
return cast(

View File

@ -8,7 +8,7 @@ import logging
import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -57,6 +57,7 @@ def get_context_string(
def sort_context(
local_context: list[dict],
tokenizer: Tokenizer,
sub_community_reports: list[dict] | None = None,
max_context_tokens: int | None = None,
) -> str:
@ -73,7 +74,7 @@ def sort_context(
new_context_string = get_context_string(
current_text_units, sub_community_reports
)
if num_tokens(new_context_string) > max_context_tokens:
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
break
context_string = new_context_string

View File

@ -7,9 +7,9 @@ import json
from dataclasses import dataclass
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.tokenizer.get_tokenizer import get_tokenizer
# these tokens are used in the prompt
ENTITY_NAME_KEY = "entity_name"
@ -45,7 +45,7 @@ class SummarizeExtractor:
"""Init method definition."""
# TODO: streamline construction
self._model = model_invoker
self._tokenizer = get_tokenizer(model_invoker.config)
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_summary_length = max_summary_length
@ -85,14 +85,14 @@ class SummarizeExtractor:
descriptions = sorted(descriptions)
# Iterate over descriptions, adding all until the max input tokens is reached
usable_tokens = self._max_input_tokens - num_tokens_from_string(
usable_tokens = self._max_input_tokens - self._tokenizer.num_tokens(
self._summarization_prompt
)
descriptions_collected = []
result = ""
for i, description in enumerate(descriptions):
usable_tokens -= num_tokens_from_string(description)
usable_tokens -= self._tokenizer.num_tokens(description)
descriptions_collected.append(description)
# If buffer is full, or all descriptions have been added, summarize
@ -109,8 +109,8 @@ class SummarizeExtractor:
descriptions_collected = [result]
usable_tokens = (
self._max_input_tokens
- num_tokens_from_string(self._summarization_prompt)
- num_tokens_from_string(result)
- self._tokenizer.num_tokens(self._summarization_prompt)
- self._tokenizer.num_tokens(result)
)
return result
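A self-contained sketch of the token-budget arithmetic above, using the tokenizer interface the extractor now holds (the prompt and descriptions are stand-ins, not the real SUMMARIZE_PROMPT):

```python
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # default encoding; the extractor derives its tokenizer from the model config
summarization_prompt = "Summarize the following descriptions of {entity_name}: {description_list}"
descriptions = sorted(["a research laboratory", "a lab focused on graph-based RAG"])

max_input_tokens = 4_000
# budget remaining after accounting for the fixed prompt
usable_tokens = max_input_tokens - tokenizer.num_tokens(summarization_prompt)

collected: list[str] = []
for description in descriptions:
    usable_tokens -= tokenizer.num_tokens(description)
    collected.append(description)
    if usable_tokens <= 0:
        # in SummarizeExtractor this is where an intermediate summary is produced
        # and the budget is reset from the summary's token count
        break
```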

View File

@ -94,7 +94,7 @@ class TokenTextSplitter(TextSplitter):
def num_tokens(self, text: str) -> int:
"""Return the number of tokens in a string."""
return len(self._tokenizer.encode(text))
return self._tokenizer.num_tokens(text)
def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""

View File

@ -1,44 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Utilities for working with tokens."""
import logging
import tiktoken
import graphrag.config.defaults as defs
DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL
logger = logging.getLogger(__name__)
def num_tokens_from_string(
string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
"""Return the number of tokens in a text string."""
if model is not None:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
logger.warning(msg)
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
else:
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
return len(encoding.encode(string))
def string_from_tokens(
tokens: list[int], model: str | None = None, encoding_name: str | None = None
) -> str:
"""Return a text string from a list of tokens."""
if model is not None:
encoding = tiktoken.encoding_for_model(model)
elif encoding_name is not None:
encoding = tiktoken.get_encoding(encoding_name)
else:
msg = "Either model or encoding_name must be specified."
raise ValueError(msg)
return encoding.decode(tokens)
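The deleted helpers map onto the tokenizer protocol used everywhere else in this PR; a hedged migration sketch:

```python
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()  # defaults to the standard encoding when no model config is given

# num_tokens_from_string(text, model=...)  ->  tokenizer.num_tokens(text)
count = tokenizer.num_tokens("some entity description")

# string_from_tokens(tokens, model=...)    ->  tokenizer.decode(tokens)
tokens = tokenizer.encode("some entity description")
text = tokenizer.decode(tokens)
```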

View File

@ -13,6 +13,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@ -28,6 +29,7 @@ from graphrag.index.operations.summarize_communities.summarize_communities impor
)
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
@ -102,6 +104,9 @@ async def create_community_reports(
summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
model_config = LanguageModelConfig(**summarization_strategy["llm"])
tokenizer = get_tokenizer(model_config)
max_input_length = summarization_strategy.get(
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
)
@ -110,6 +115,7 @@ async def create_community_reports(
nodes,
edges,
claims,
tokenizer,
callbacks,
max_input_length,
)
@ -122,6 +128,7 @@ async def create_community_reports(
callbacks,
cache,
summarization_strategy,
tokenizer=tokenizer,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,

View File

@ -12,6 +12,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@ -27,6 +28,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
)
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@ -88,8 +90,11 @@ async def create_community_reports_text(
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
)
model_config = LanguageModelConfig(**summarization_strategy["llm"])
tokenizer = get_tokenizer(model_config)
local_contexts = build_local_context(
communities, text_units, nodes, max_input_length
communities, text_units, nodes, tokenizer, max_input_length
)
community_reports = await summarize_communities(
@ -100,6 +105,7 @@ async def create_community_reports_text(
callbacks,
cache,
summarization_strategy,
tokenizer=tokenizer,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,

View File

@ -56,9 +56,7 @@ class ProgressTicker:
description=self._description,
)
if p.description:
logger.info(
"%s%s/%s", p.description, str(p.completed_items), str(p.total_items)
)
logger.info("%s%s/%s", p.description, p.completed_items, p.total_items)
self._callback(p)
def done(self) -> None:

View File

@ -9,21 +9,15 @@ import re
from collections.abc import Iterator
from itertools import islice
import tiktoken
from json_repair import repair_json
import graphrag.config.defaults as defs
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int:
"""Return the number of tokens in the given text."""
if token_encoder is None:
token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
return len(token_encoder.encode(text)) # type: ignore
def batched(iterable: Iterator, n: int):
"""
Batch data into tuples of length n. The last batch may be shorter.
@ -39,15 +33,13 @@ def batched(iterable: Iterator, n: int):
yield batch
def chunk_text(
text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None
):
def chunk_text(text: str, max_tokens: int, tokenizer: Tokenizer | None = None):
"""Chunk text by token length."""
if token_encoder is None:
token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
tokens = token_encoder.encode(text) # type: ignore
if tokenizer is None:
tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
tokens = tokenizer.encode(text) # type: ignore
chunk_iterator = batched(iter(tokens), max_tokens)
yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator)
yield from (tokenizer.decode(list(chunk)) for chunk in chunk_iterator)
def try_parse_json_object(input: str, verbose: bool = True) -> tuple[str, dict]:
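A short usage sketch of the reworked `chunk_text` (the import path is an assumption based on the `graphrag.query.llm.text_utils` imports removed elsewhere in this PR):

```python
from graphrag.query.llm.text_utils import chunk_text  # assumed module path
from graphrag.tokenizer.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer()
text = "GraphRAG builds a knowledge graph over your documents and then queries it."

# previously this accepted a tiktoken Encoding via token_encoder; now it takes a Tokenizer
for chunk in chunk_text(text, max_tokens=8, tokenizer=tokenizer):
    print(repr(chunk))
```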

View File

@ -7,13 +7,13 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import tiktoken
from graphrag.language_model.protocol.base import ChatModel
from graphrag.query.context_builder.builders import (
GlobalContextBuilder,
LocalContextBuilder,
)
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
@dataclass
@ -34,13 +34,13 @@ class BaseQuestionGen(ABC):
self,
model: ChatModel,
context_builder: GlobalContextBuilder | LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
model_params: dict[str, Any] | None = None,
context_builder_params: dict[str, Any] | None = None,
):
self.model = model
self.context_builder = context_builder
self.token_encoder = token_encoder
self.tokenizer = tokenizer or get_tokenizer(model.config)
self.model_params = model_params or {}
self.context_builder_params = context_builder_params or {}

View File

@ -7,8 +7,6 @@ import logging
import time
from typing import Any, cast
import tiktoken
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
@ -19,8 +17,8 @@ from graphrag.query.context_builder.builders import (
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -32,7 +30,7 @@ class LocalQuestionGen(BaseQuestionGen):
self,
model: ChatModel,
context_builder: LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
tokenizer: Tokenizer | None = None,
system_prompt: str = QUESTION_SYSTEM_PROMPT,
callbacks: list[BaseLLMCallback] | None = None,
model_params: dict[str, Any] | None = None,
@ -41,7 +39,7 @@ class LocalQuestionGen(BaseQuestionGen):
super().__init__(
model=model,
context_builder=context_builder,
token_encoder=token_encoder,
tokenizer=tokenizer,
model_params=model_params,
context_builder_params=context_builder_params,
)
@ -118,7 +116,7 @@ class LocalQuestionGen(BaseQuestionGen):
},
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
)
except Exception:
@ -128,7 +126,7 @@ class LocalQuestionGen(BaseQuestionGen):
context_data=context_records,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
)
async def generate(
@ -201,7 +199,7 @@ class LocalQuestionGen(BaseQuestionGen):
},
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
)
except Exception:
@ -211,5 +209,5 @@ class LocalQuestionGen(BaseQuestionGen):
context_data=context_records,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
)

View File

@ -3,19 +3,15 @@
"""Get Tokenizer."""
from typing import TYPE_CHECKING
from graphrag.config.defaults import ENCODING_MODEL
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def get_tokenizer(
model_config: "LanguageModelConfig | None" = None,
model_config: LanguageModelConfig | None = None,
encoding_model: str = ENCODING_MODEL,
) -> Tokenizer:
"""

View File

@ -130,7 +130,7 @@ def get_embedding_store(
embedding_store = VectorStoreFactory().create_vector_store(
vector_store_type=vector_store_type,
vector_store_schema_config=single_embedding_config,
kwargs={**store},
**store,
)
embedding_store.connect(**store)
# If there is only a single index, return the embedding store directly

View File

@ -53,7 +53,7 @@ class VectorStoreFactory:
cls,
vector_store_type: str,
vector_store_schema_config: VectorStoreSchemaConfig,
kwargs: dict,
**kwargs: dict,
) -> BaseVectorStore:
"""Create a vector store object from the provided type.

View File

@ -69,7 +69,7 @@ dependencies = [
"litellm>=1.77.1",
]
[project.optional-dependencies]
[dependency-groups]
dev = [
"coverage>=7.6.9",
"ipykernel>=6.29.5",
@ -239,6 +239,7 @@ ignore = [
"PERF203", # Needs restructuring of errors, we should bail-out on first error
"C901", # needs refactoring to remove cyclomatic complexity
"B008", # Needs to restructure our cli params with Typer into constants
"ASYNC240"
]
[tool.ruff.lint.per-file-ignores]

View File

@ -81,9 +81,7 @@ def test_register_and_create_custom_vector_store():
)
vector_store = VectorStoreFactory.create_vector_store(
vector_store_type="custom",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
vector_store_type="custom", vector_store_schema_config=VectorStoreSchemaConfig()
)
assert custom_vector_store_class.called
@ -109,7 +107,6 @@ def test_create_unknown_vector_store():
VectorStoreFactory.create_vector_store(
vector_store_type="unknown",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={},
)
@ -162,7 +159,6 @@ def test_register_class_directly_works():
vector_store = VectorStoreFactory.create_vector_store(
vector_store_type="custom_class",
vector_store_schema_config=VectorStoreSchemaConfig(),
kwargs={"collection_name": "test"},
)
assert isinstance(vector_store, CustomVectorStore)

View File

@ -6,7 +6,7 @@ import platform
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.get_tokenizer import get_tokenizer
nan = math.nan
@ -204,16 +204,18 @@ context: list[dict] = [
def test_sort_context():
ctx = sort_context(context)
tokenizer = get_tokenizer()
ctx = sort_context(context, tokenizer=tokenizer)
assert ctx is not None, "Context is none"
num = num_tokens(ctx)
num = tokenizer.num_tokens(ctx)
assert num == 828 if platform.system() == "Windows" else 826, (
f"num_tokens is not matched for platform (win = 827, else 826): {num}"
)
def test_sort_context_max_tokens():
ctx = sort_context(context, max_context_tokens=800)
tokenizer = get_tokenizer()
ctx = sort_context(context, tokenizer=tokenizer, max_context_tokens=800)
assert ctx is not None, "Context is none"
num = num_tokens(ctx)
num = tokenizer.num_tokens(ctx)
assert num <= 800, f"num_tokens is not less than or equal to 800: {num}"

View File

@ -91,7 +91,7 @@ You can host Unified Search datasets locally or in a blob.
# Run the app
Install all the dependencies: `uv sync --extra dev`
Install all the dependencies: `uv sync`
Run the project using streamlit: `uv run poe start`

uv.lock (generated, 591 changes)

File diff suppressed because it is too large