Mirror of https://github.com/microsoft/graphrag.git (synced 2026-01-13 16:47:20 +08:00)
Housekeeping (#2086)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Add deprecation warnings for fnllm and multi-search
* Fix dangling token_encoder refs
* Fix local_search notebook
* Fix global search dynamic notebook
* Fix global search notebook
* Fix drift notebook
* Switch example notebooks to use LiteLLM config
* Properly annotate dev deps as a group
* Semver
* Remove --extra dev
* Remove llm_model variable
* Ignore ruff ASYNC240
* Add note about expected broken notebook in docs
* Fix custom vector store notebook
* Push tokenizer throughout
This commit is contained in:
parent
6c86b0a7bb
commit
ac8a7f5eef
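The thrust of this housekeeping pass is visible throughout the diff below: the example notebooks drop the deprecated fnllm-style model types (`openai_chat` / `openai_embedding`) in favor of the generic LiteLLM config (`ModelType.Chat` / `ModelType.Embedding` plus an explicit `model_provider`), and raw `tiktoken` token encoders are replaced by the `Tokenizer` returned from `get_tokenizer`. A minimal sketch of the new pattern, assembled from the notebook diffs (the `gpt-4.1` model name and `GRAPHRAG_API_KEY` variable are simply the values the notebooks use):

```python
import os

from graphrag.config.enums import ModelType
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# LiteLLM-style config: generic ModelType.Chat with an explicit provider,
# replacing the deprecated ModelType.OpenAIChat.
chat_config = LanguageModelConfig(
    api_key=os.environ["GRAPHRAG_API_KEY"],
    type=ModelType.Chat,
    model_provider="openai",
    model="gpt-4.1",
    max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.Chat,
    config=chat_config,
)

# The tokenizer is now derived from the model config and threaded through
# the search and context-builder APIs instead of a tiktoken token_encoder.
tokenizer = get_tokenizer(chat_config)
print(tokenizer.num_tokens("hello graphrag"))
```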
.github/workflows/gh-pages.yml
@ -31,7 +31,7 @@ jobs:
- name: Install dependencies
shell: bash
run: uv sync --extra dev
run: uv sync

- name: mkdocs build
shell: bash
.github/workflows/python-ci.yml
@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim

- name: Check

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim

- name: Build
.github/workflows/python-notebook-tests.yml
@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim

- name: Notebook Test
.github/workflows/python-smoke-tests.yml
@ -72,7 +72,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim

- name: Build
@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Housekeeping toward 2.7."
}
@ -11,12 +11,8 @@
## Install Dependencies

```shell
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate

# install python dependencies
uv sync --extra dev
uv sync
```

## Execute the indexing engine
@ -12,12 +12,8 @@
## Install Dependencies

```sh
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate

# install python dependencies
uv sync --extra dev
uv sync
```

## Execute the Indexing Engine
@ -67,6 +67,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# note that we expect this to fail on the deployed docs because the PROJECT_DIRECTORY is not set to a real location.\n",
|
||||
"# if you run this notebook locally, make sure to point at a location containing your settings.yaml\n",
|
||||
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
|
||||
]
|
||||
},
|
||||
|
||||
@ -61,6 +61,7 @@
|
||||
"import numpy as np\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.data_model.types import TextEmbedder\n",
|
||||
"\n",
|
||||
"# GraphRAG vector store components\n",
|
||||
@ -147,14 +148,12 @@
|
||||
" self.vectors: dict[str, np.ndarray] = {}\n",
|
||||
" self.connected = False\n",
|
||||
"\n",
|
||||
" print(\n",
|
||||
" f\"🚀 SimpleInMemoryVectorStore initialized for collection: {self.collection_name}\"\n",
|
||||
" )\n",
|
||||
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def connect(self, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
|
||||
" self.connected = True\n",
|
||||
" print(f\"✅ Connected to in-memory vector store: {self.collection_name}\")\n",
|
||||
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def load_documents(\n",
|
||||
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
|
||||
@ -250,7 +249,7 @@
|
||||
" def get_stats(self) -> dict[str, Any]:\n",
|
||||
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
|
||||
" return {\n",
|
||||
" \"collection_name\": self.collection_name,\n",
|
||||
" \"index_name\": self.index_name,\n",
|
||||
" \"document_count\": len(self.documents),\n",
|
||||
" \"vector_count\": len(self.vectors),\n",
|
||||
" \"connected\": self.connected,\n",
|
||||
@ -353,11 +352,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test creating vector store using the factory\n",
|
||||
"vector_store_config = {\"collection_name\": \"test_collection\"}\n",
|
||||
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
|
||||
"\n",
|
||||
"# Create vector store instance using factory\n",
|
||||
"vector_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, vector_store_config\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
|
||||
@ -486,9 +485,13 @@
|
||||
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
|
||||
"\n",
|
||||
" # 1. GraphRAG creates vector store using factory\n",
|
||||
" config = {\"collection_name\": \"graphrag_entities\", \"similarity_threshold\": 0.3}\n",
|
||||
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
|
||||
"\n",
|
||||
" store = VectorStoreFactory.create_vector_store(CUSTOM_VECTOR_STORE_TYPE, config)\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=schema,\n",
|
||||
" similarity_threshold=0.3,\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
" print(\"✅ Step 1: Vector store created and connected\")\n",
|
||||
@ -549,7 +552,8 @@
|
||||
" # Test 1: Basic functionality\n",
|
||||
" print(\"Test 1: Basic functionality\")\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test\"}\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
@ -597,7 +601,8 @@
|
||||
" # Test 5: Error handling\n",
|
||||
" print(\"\\nTest 5: Error handling\")\n",
|
||||
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, {\"collection_name\": \"test2\"}\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
@ -653,7 +658,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag-venv (3.10.18)",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -667,7 +672,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.18"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -20,11 +20,11 @@
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_entities,\n",
|
||||
@ -37,6 +37,7 @@
|
||||
" DRIFTSearchContextBuilder,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
|
||||
"\n",
|
||||
"INPUT_DIR = \"./inputs/operation dulce\"\n",
|
||||
@ -62,12 +63,16 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
"full_content_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-community-full_content\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-community-full_content\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -94,33 +99,33 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"tokenizer = get_tokenizer(chat_config)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
]
|
||||
@ -173,12 +178,12 @@
|
||||
" reports=reports,\n",
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" text_units=text_units,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" config=drift_params,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"search = DRIFTSearch(\n",
|
||||
" model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
|
||||
" model=chat_model, context_builder=context_builder, tokenizer=tokenizer\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -212,7 +217,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -226,7 +231,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -32,7 +31,8 @@
|
||||
"from graphrag.query.structured_search.global_search.community_context import (\n",
|
||||
" GlobalCommunityContext,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -58,21 +58,21 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -142,7 +142,7 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -193,7 +193,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -241,7 +241,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -255,7 +255,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -57,22 +56,24 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -155,11 +156,11 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" dynamic_community_selection=True,\n",
|
||||
" dynamic_community_selection_kwargs={\n",
|
||||
" \"model\": model,\n",
|
||||
" \"token_encoder\": token_encoder,\n",
|
||||
" \"tokenizer\": tokenizer,\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
@ -211,7 +212,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -255,7 +256,7 @@
|
||||
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
|
||||
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
|
||||
"print(\n",
|
||||
" f\"Build context ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Build context LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")\n",
|
||||
"# inspect number of LLM calls and tokens in map-reduce\n",
|
||||
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
|
||||
@ -266,14 +267,14 @@
|
||||
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"Map-reduce ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Map-reduce LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -287,7 +288,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,8 +19,8 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_covariates,\n",
|
||||
@ -102,7 +102,9 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -195,37 +197,38 @@
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
")\n",
|
||||
"\n",
|
||||
"tokenizer = get_tokenizer(chat_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -251,7 +254,7 @@
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE\n",
|
||||
" text_embedder=text_embedder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -314,7 +317,7 @@
|
||||
"search_engine = LocalSearch(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
|
||||
@ -426,7 +429,7 @@
|
||||
"question_generator = LocalQuestionGen(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
")"
|
||||
@ -451,7 +454,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -465,7 +468,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
- `llm`: OpenAI model object to be used for response generation
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
- `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
- `token_encoder`: token encoder for tracking the budget for the algorithm.
- `tokenizer`: token encoder for tracking the budget for the algorithm.
- `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).

## How to Use
@ -231,6 +231,10 @@ async def multi_index_global_search(
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_global_search"
|
||||
@ -510,6 +514,9 @@ async def multi_index_local_search(
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_index_local_search"
|
||||
@ -874,6 +881,10 @@ async def multi_index_drift_search(
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_drift_search"
|
||||
@ -1166,6 +1177,10 @@ async def multi_index_basic_search(
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_basic_search"
|
||||
|
||||
@ -11,7 +11,6 @@ def get_embedding_settings(
|
||||
vector_store_params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Transform GraphRAG config into settings for workflows."""
|
||||
# TEMP
|
||||
embeddings_llm_settings = settings.get_language_model_config(
|
||||
settings.embed_text.model_id
|
||||
)
|
||||
|
||||
@ -98,6 +98,14 @@ class LanguageModelConfig(BaseModel):
        if not ModelFactory.is_supported_model(self.type):
            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
            raise KeyError(msg)
        if self.type in [
            "openai_chat",
            "openai_embedding",
            "azure_openai_chat",
            "azure_openai_embedding",
        ]:
            msg = f"Model config based on fnllm is deprecated and will be removed in GraphRAG v3, please use {ModelType.Chat} or {ModelType.Embedding} instead to switch to LiteLLM config."
            logger.warning(msg)

    model_provider: str | None = Field(
        description="The model provider to use.",
@ -210,7 +210,7 @@ def _create_vector_store(
    vector_store = VectorStoreFactory().create_vector_store(
        vector_store_schema_config=single_embedding_config,
        vector_store_type=vector_store_type,
        kwargs=vector_store_config,
        **vector_store_config,
    )

    vector_store.connect(**vector_store_config)
|
||||
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
|
||||
sort_context,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
def build_mixed_context(
|
||||
context: list[dict], tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> str:
|
||||
"""
|
||||
Build parent context by concatenating all sub-communities' contexts.
|
||||
|
||||
@ -45,9 +47,10 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
|
||||
new_context_string = sort_context(
|
||||
local_context=remaining_local_context + final_local_contexts,
|
||||
tokenizer=tokenizer,
|
||||
sub_community_reports=substitute_reports,
|
||||
)
|
||||
if num_tokens(new_context_string) <= max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) <= max_context_tokens:
|
||||
exceeded_limit = False
|
||||
context_string = new_context_string
|
||||
break
|
||||
@ -63,7 +66,7 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
new_context_string = pd.DataFrame(substitute_reports).to_csv(
|
||||
index=False, sep=","
|
||||
)
|
||||
if num_tokens(new_context_string) > max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
|
||||
break
|
||||
|
||||
context_string = new_context_string
|
||||
|
||||
@ -30,7 +30,7 @@ from graphrag.index.utils.dataframes import (
|
||||
where_column_equals,
|
||||
)
|
||||
from graphrag.logger.progress import progress_iterable
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,6 +39,7 @@ def build_local_context(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
tokenizer: Tokenizer,
|
||||
callbacks: WorkflowCallbacks,
|
||||
max_context_tokens: int = 16_000,
|
||||
):
|
||||
@ -49,7 +50,7 @@ def build_local_context(
|
||||
|
||||
for level in progress_iterable(levels, callbacks.progress, len(levels)):
|
||||
communities_at_level_df = _prepare_reports_at_level(
|
||||
nodes, edges, claims, level, max_context_tokens
|
||||
nodes, edges, claims, tokenizer, level, max_context_tokens
|
||||
)
|
||||
|
||||
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
|
||||
@ -63,6 +64,7 @@ def _prepare_reports_at_level(
|
||||
node_df: pd.DataFrame,
|
||||
edge_df: pd.DataFrame,
|
||||
claim_df: pd.DataFrame | None,
|
||||
tokenizer: Tokenizer,
|
||||
level: int,
|
||||
max_context_tokens: int = 16_000,
|
||||
) -> pd.DataFrame:
|
||||
@ -181,6 +183,7 @@ def _prepare_reports_at_level(
|
||||
# Generate community-level context strings using vectorized batch processing
|
||||
return parallel_sort_context_batch(
|
||||
community_df,
|
||||
tokenizer=tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
|
||||
@ -189,6 +192,7 @@ def build_level_context(
|
||||
report_df: pd.DataFrame | None,
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
local_context_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
level: int,
|
||||
max_context_tokens: int,
|
||||
) -> pd.DataFrame:
|
||||
@ -219,11 +223,11 @@ def build_level_context(
|
||||
|
||||
if report_df is None or report_df.empty:
|
||||
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
|
||||
invalid_context_df, max_context_tokens
|
||||
invalid_context_df, tokenizer, max_context_tokens
|
||||
)
|
||||
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
|
||||
:, schemas.CONTEXT_STRING
|
||||
].map(num_tokens)
|
||||
].map(tokenizer.num_tokens)
|
||||
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
return union(valid_context_df, invalid_context_df)
|
||||
|
||||
@ -237,6 +241,7 @@ def build_level_context(
|
||||
invalid_context_df,
|
||||
sub_context_df,
|
||||
community_hierarchy_df,
|
||||
tokenizer,
|
||||
max_context_tokens,
|
||||
)
|
||||
|
||||
@ -244,11 +249,13 @@ def build_level_context(
|
||||
# this should be rare, but if it happens, we will just trim the local context to fit the limit
|
||||
remaining_df = _antijoin_reports(invalid_context_df, community_df)
|
||||
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
|
||||
remaining_df, max_context_tokens
|
||||
remaining_df, tokenizer, max_context_tokens
|
||||
)
|
||||
|
||||
result = union(valid_context_df, community_df, remaining_df)
|
||||
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
|
||||
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(
|
||||
tokenizer.num_tokens
|
||||
)
|
||||
|
||||
result[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
return result
|
||||
@ -269,19 +276,29 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
|
||||
return antijoin(df, reports, schemas.COMMUNITY_ID)
|
||||
|
||||
|
||||
def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
|
||||
def _sort_and_trim_context(
|
||||
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> pd.Series:
|
||||
"""Sort and trim context to fit the limit."""
|
||||
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
|
||||
return transform_series(
|
||||
series, lambda x: sort_context(x, max_context_tokens=max_context_tokens)
|
||||
series,
|
||||
lambda x: sort_context(
|
||||
x, tokenizer=tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
|
||||
def _build_mixed_context(
|
||||
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> pd.Series:
|
||||
"""Sort and trim context to fit the limit."""
|
||||
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
|
||||
return transform_series(
|
||||
series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens)
|
||||
series,
|
||||
lambda x: build_mixed_context(
|
||||
x, tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -303,6 +320,7 @@ def _get_community_df(
|
||||
invalid_context_df: pd.DataFrame,
|
||||
sub_context_df: pd.DataFrame,
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int,
|
||||
) -> pd.DataFrame:
|
||||
"""Get community context for each community."""
|
||||
@ -338,7 +356,7 @@ def _get_community_df(
|
||||
.reset_index()
|
||||
)
|
||||
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
|
||||
community_df, max_context_tokens
|
||||
community_df, tokenizer, max_context_tokens
|
||||
)
|
||||
community_df[schemas.COMMUNITY_LEVEL] = level
|
||||
return community_df
|
||||
|
||||
@ -5,11 +5,12 @@
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.data_model.schemas as schemas
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def sort_context(
|
||||
local_context: list[dict],
|
||||
tokenizer: Tokenizer,
|
||||
sub_community_reports: list[dict] | None = None,
|
||||
max_context_tokens: int | None = None,
|
||||
node_name_column: str = schemas.TITLE,
|
||||
@ -112,7 +113,10 @@ def sort_context(
|
||||
new_context_string = _get_context_string(
|
||||
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
|
||||
)
|
||||
if max_context_tokens and num_tokens(new_context_string) > max_context_tokens:
|
||||
if (
|
||||
max_context_tokens
|
||||
and tokenizer.num_tokens(new_context_string) > max_context_tokens
|
||||
):
|
||||
break
|
||||
context_string = new_context_string
|
||||
|
||||
@ -122,7 +126,9 @@ def sort_context(
|
||||
)
|
||||
|
||||
|
||||
def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False):
|
||||
def parallel_sort_context_batch(
|
||||
community_df, tokenizer: Tokenizer, max_context_tokens, parallel=False
|
||||
):
|
||||
"""Calculate context using parallelization if enabled."""
|
||||
if parallel:
|
||||
# Use ThreadPoolExecutor for parallel execution
|
||||
@ -131,7 +137,9 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
|
||||
with ThreadPoolExecutor(max_workers=None) as executor:
|
||||
context_strings = list(
|
||||
executor.map(
|
||||
lambda x: sort_context(x, max_context_tokens=max_context_tokens),
|
||||
lambda x: sort_context(
|
||||
x, tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
community_df[schemas.ALL_CONTEXT],
|
||||
)
|
||||
)
|
||||
@ -141,13 +149,13 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
|
||||
# Assign context strings directly to the DataFrame
|
||||
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda context_list: sort_context(
|
||||
context_list, max_context_tokens=max_context_tokens
|
||||
context_list, tokenizer, max_context_tokens=max_context_tokens
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate other columns
|
||||
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
|
||||
num_tokens
|
||||
tokenizer.num_tokens
|
||||
)
|
||||
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
|
||||
community_df[schemas.CONTEXT_SIZE] > max_context_tokens
|
||||
|
||||
@ -23,6 +23,7 @@ from graphrag.index.operations.summarize_communities.utils import (
|
||||
)
|
||||
from graphrag.index.utils.derive_from_rows import derive_from_rows
|
||||
from graphrag.logger.progress import progress_ticker
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -35,6 +36,7 @@ async def summarize_communities(
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict,
|
||||
tokenizer: Tokenizer,
|
||||
max_input_length: int,
|
||||
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||
num_threads: int = 4,
|
||||
@ -44,7 +46,6 @@ async def summarize_communities(
|
||||
tick = progress_ticker(callbacks.progress, len(local_contexts))
|
||||
strategy_exec = load_strategy(strategy["type"])
|
||||
strategy_config = {**strategy}
|
||||
|
||||
community_hierarchy = (
|
||||
communities.explode("children")
|
||||
.rename({"children": "sub_community"}, axis=1)
|
||||
@ -60,6 +61,7 @@ async def summarize_communities(
|
||||
community_hierarchy_df=community_hierarchy,
|
||||
local_context_df=local_contexts,
|
||||
level=level,
|
||||
tokenizer=tokenizer,
|
||||
max_context_tokens=max_input_length,
|
||||
)
|
||||
level_contexts.append(level_context)
|
||||
|
||||
@ -18,7 +18,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
|
||||
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
|
||||
sort_context,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,6 +27,7 @@ def build_local_context(
|
||||
community_membership_df: pd.DataFrame,
|
||||
text_units_df: pd.DataFrame,
|
||||
node_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int = 16000,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -69,10 +70,10 @@ def build_local_context(
|
||||
.reset_index()
|
||||
)
|
||||
context_df[schemas.CONTEXT_STRING] = context_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda x: sort_context(x)
|
||||
lambda x: sort_context(x, tokenizer)
|
||||
)
|
||||
context_df[schemas.CONTEXT_SIZE] = context_df[schemas.CONTEXT_STRING].apply(
|
||||
lambda x: num_tokens(x)
|
||||
lambda x: tokenizer.num_tokens(x)
|
||||
)
|
||||
context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply(
|
||||
lambda x: x > max_context_tokens
|
||||
@ -86,6 +87,7 @@ def build_level_context(
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
local_context_df: pd.DataFrame,
|
||||
level: int,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int = 16000,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -116,10 +118,12 @@ def build_level_context(
|
||||
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[
|
||||
schemas.ALL_CONTEXT
|
||||
].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
|
||||
].apply(
|
||||
lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens)
|
||||
)
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[
|
||||
schemas.CONTEXT_STRING
|
||||
].apply(lambda x: num_tokens(x))
|
||||
].apply(lambda x: tokenizer.num_tokens(x))
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_EXCEED_FLAG]] = False
|
||||
|
||||
return pd.concat([valid_context_df, invalid_context_df])
|
||||
@ -199,10 +203,10 @@ def build_level_context(
|
||||
.reset_index()
|
||||
)
|
||||
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda x: build_mixed_context(x, max_context_tokens)
|
||||
lambda x: build_mixed_context(x, tokenizer, max_context_tokens)
|
||||
)
|
||||
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
|
||||
lambda x: num_tokens(x)
|
||||
lambda x: tokenizer.num_tokens(x)
|
||||
)
|
||||
community_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
community_df[schemas.COMMUNITY_LEVEL] = level
|
||||
@ -220,10 +224,10 @@ def build_level_context(
|
||||
)
|
||||
remaining_df[schemas.CONTEXT_STRING] = cast(
|
||||
"pd.DataFrame", remaining_df[schemas.ALL_CONTEXT]
|
||||
).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
|
||||
).apply(lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens))
|
||||
remaining_df[schemas.CONTEXT_SIZE] = cast(
|
||||
"pd.DataFrame", remaining_df[schemas.CONTEXT_STRING]
|
||||
).apply(lambda x: num_tokens(x))
|
||||
).apply(lambda x: tokenizer.num_tokens(x))
|
||||
remaining_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
|
||||
return cast(
|
||||
|
||||
@ -8,7 +8,7 @@ import logging
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.data_model.schemas as schemas
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,6 +57,7 @@ def get_context_string(
|
||||
|
||||
def sort_context(
|
||||
local_context: list[dict],
|
||||
tokenizer: Tokenizer,
|
||||
sub_community_reports: list[dict] | None = None,
|
||||
max_context_tokens: int | None = None,
|
||||
) -> str:
|
||||
@ -73,7 +74,7 @@ def sort_context(
|
||||
new_context_string = get_context_string(
|
||||
current_text_units, sub_community_reports
|
||||
)
|
||||
if num_tokens(new_context_string) > max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
|
||||
break
|
||||
|
||||
context_string = new_context_string
|
||||
|
||||
@ -7,9 +7,9 @@ import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.index.typing.error_handler import ErrorHandlerFn
|
||||
from graphrag.index.utils.tokens import num_tokens_from_string
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
# these tokens are used in the prompt
|
||||
ENTITY_NAME_KEY = "entity_name"
|
||||
@ -45,7 +45,7 @@ class SummarizeExtractor:
|
||||
"""Init method definition."""
|
||||
# TODO: streamline construction
|
||||
self._model = model_invoker
|
||||
|
||||
self._tokenizer = get_tokenizer(model_invoker.config)
|
||||
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_summary_length = max_summary_length
|
||||
@ -85,14 +85,14 @@ class SummarizeExtractor:
|
||||
descriptions = sorted(descriptions)
|
||||
|
||||
# Iterate over descriptions, adding all until the max input tokens is reached
|
||||
usable_tokens = self._max_input_tokens - num_tokens_from_string(
|
||||
usable_tokens = self._max_input_tokens - self._tokenizer.num_tokens(
|
||||
self._summarization_prompt
|
||||
)
|
||||
descriptions_collected = []
|
||||
result = ""
|
||||
|
||||
for i, description in enumerate(descriptions):
|
||||
usable_tokens -= num_tokens_from_string(description)
|
||||
usable_tokens -= self._tokenizer.num_tokens(description)
|
||||
descriptions_collected.append(description)
|
||||
|
||||
# If buffer is full, or all descriptions have been added, summarize
|
||||
@ -109,8 +109,8 @@ class SummarizeExtractor:
|
||||
descriptions_collected = [result]
|
||||
usable_tokens = (
|
||||
self._max_input_tokens
|
||||
- num_tokens_from_string(self._summarization_prompt)
|
||||
- num_tokens_from_string(result)
|
||||
- self._tokenizer.num_tokens(self._summarization_prompt)
|
||||
- self._tokenizer.num_tokens(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -94,7 +94,7 @@ class TokenTextSplitter(TextSplitter):
|
||||
|
||||
def num_tokens(self, text: str) -> int:
|
||||
"""Return the number of tokens in a string."""
|
||||
return len(self._tokenizer.encode(text))
|
||||
return self._tokenizer.num_tokens(text)
|
||||
|
||||
def split_text(self, text: str | list[str]) -> list[str]:
|
||||
"""Split text method."""
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Utilities for working with tokens."""
|
||||
|
||||
import logging
|
||||
|
||||
import tiktoken
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
|
||||
DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def num_tokens_from_string(
|
||||
string: str, model: str | None = None, encoding_name: str | None = None
|
||||
) -> int:
|
||||
"""Return the number of tokens in a text string."""
|
||||
if model is not None:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
|
||||
logger.warning(msg)
|
||||
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
|
||||
else:
|
||||
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
def string_from_tokens(
|
||||
tokens: list[int], model: str | None = None, encoding_name: str | None = None
|
||||
) -> str:
|
||||
"""Return a text string from a list of tokens."""
|
||||
if model is not None:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
elif encoding_name is not None:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
else:
|
||||
msg = "Either model or encoding_name must be specified."
|
||||
raise ValueError(msg)
|
||||
return encoding.decode(tokens)
|
||||
@ -13,6 +13,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.operations.finalize_community_reports import (
|
||||
finalize_community_reports,
|
||||
)
|
||||
@ -28,6 +29,7 @@ from graphrag.index.operations.summarize_communities.summarize_communities impor
|
||||
)
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.utils.storage import (
|
||||
load_table_from_storage,
|
||||
storage_has_table,
|
||||
@ -102,6 +104,9 @@ async def create_community_reports(
|
||||
|
||||
summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
|
||||
|
||||
model_config = LanguageModelConfig(**summarization_strategy["llm"])
|
||||
tokenizer = get_tokenizer(model_config)
|
||||
|
||||
max_input_length = summarization_strategy.get(
|
||||
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
|
||||
)
|
||||
@ -110,6 +115,7 @@ async def create_community_reports(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
tokenizer,
|
||||
callbacks,
|
||||
max_input_length,
|
||||
)
|
||||
@ -122,6 +128,7 @@ async def create_community_reports(
|
||||
callbacks,
|
||||
cache,
|
||||
summarization_strategy,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=max_input_length,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
|
||||
@ -12,6 +12,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.operations.finalize_community_reports import (
|
||||
finalize_community_reports,
|
||||
)
|
||||
@ -27,6 +28,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
|
||||
)
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -88,8 +90,11 @@ async def create_community_reports_text(
|
||||
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
|
||||
)
|
||||
|
||||
model_config = LanguageModelConfig(**summarization_strategy["llm"])
|
||||
tokenizer = get_tokenizer(model_config)
|
||||
|
||||
local_contexts = build_local_context(
|
||||
communities, text_units, nodes, max_input_length
|
||||
communities, text_units, nodes, tokenizer, max_input_length
|
||||
)
|
||||
|
||||
community_reports = await summarize_communities(
|
||||
@ -100,6 +105,7 @@ async def create_community_reports_text(
|
||||
callbacks,
|
||||
cache,
|
||||
summarization_strategy,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=max_input_length,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
|
||||
@ -56,9 +56,7 @@ class ProgressTicker:
|
||||
description=self._description,
|
||||
)
|
||||
if p.description:
|
||||
logger.info(
|
||||
"%s%s/%s", p.description, str(p.completed_items), str(p.total_items)
|
||||
)
|
||||
logger.info("%s%s/%s", p.description, p.completed_items, p.total_items)
|
||||
self._callback(p)
|
||||
|
||||
def done(self) -> None:
|
||||
|
||||
@ -9,21 +9,15 @@ import re
|
||||
from collections.abc import Iterator
|
||||
from itertools import islice
|
||||
|
||||
import tiktoken
|
||||
from json_repair import repair_json
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int:
|
||||
"""Return the number of tokens in the given text."""
|
||||
if token_encoder is None:
|
||||
token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
|
||||
return len(token_encoder.encode(text)) # type: ignore
|
||||
|
||||
|
||||
def batched(iterable: Iterator, n: int):
|
||||
"""
|
||||
Batch data into tuples of length n. The last batch may be shorter.
|
||||
@ -39,15 +33,13 @@ def batched(iterable: Iterator, n: int):
|
||||
yield batch
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str, max_tokens: int, token_encoder: tiktoken.Encoding | None = None
|
||||
):
|
||||
def chunk_text(text: str, max_tokens: int, tokenizer: Tokenizer | None = None):
|
||||
"""Chunk text by token length."""
|
||||
if token_encoder is None:
|
||||
token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL)
|
||||
tokens = token_encoder.encode(text) # type: ignore
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL)
|
||||
tokens = tokenizer.encode(text) # type: ignore
|
||||
chunk_iterator = batched(iter(tokens), max_tokens)
|
||||
yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator)
|
||||
yield from (tokenizer.decode(list(chunk)) for chunk in chunk_iterator)
|
||||
|
||||
|
||||
def try_parse_json_object(input: str, verbose: bool = True) -> tuple[str, dict]:
|
||||
|
||||
@ -7,13 +7,13 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.query.context_builder.builders import (
|
||||
GlobalContextBuilder,
|
||||
LocalContextBuilder,
|
||||
)
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -34,13 +34,13 @@ class BaseQuestionGen(ABC):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: GlobalContextBuilder | LocalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
model_params: dict[str, Any] | None = None,
|
||||
context_builder_params: dict[str, Any] | None = None,
|
||||
):
|
||||
self.model = model
|
||||
self.context_builder = context_builder
|
||||
self.token_encoder = token_encoder
|
||||
self.tokenizer = tokenizer or get_tokenizer(model.config)
|
||||
self.model_params = model_params or {}
|
||||
self.context_builder_params = context_builder_params or {}
|
||||
|
||||
|
||||
@ -7,8 +7,6 @@ import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
|
||||
@ -19,8 +17,8 @@ from graphrag.query.context_builder.builders import (
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -32,7 +30,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
self,
|
||||
model: ChatModel,
|
||||
context_builder: LocalContextBuilder,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
tokenizer: Tokenizer | None = None,
|
||||
system_prompt: str = QUESTION_SYSTEM_PROMPT,
|
||||
callbacks: list[BaseLLMCallback] | None = None,
|
||||
model_params: dict[str, Any] | None = None,
|
||||
@ -41,7 +39,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
super().__init__(
|
||||
model=model,
|
||||
context_builder=context_builder,
|
||||
token_encoder=token_encoder,
|
||||
tokenizer=tokenizer,
|
||||
model_params=model_params,
|
||||
context_builder_params=context_builder_params,
|
||||
)
|
||||
@ -118,7 +116,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
},
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
|
||||
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
@ -128,7 +126,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
context_data=context_records,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
|
||||
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
|
||||
)
|
||||
|
||||
async def generate(
|
||||
@ -201,7 +199,7 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
},
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
|
||||
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
@ -211,5 +209,5 @@ class LocalQuestionGen(BaseQuestionGen):
|
||||
context_data=context_records,
|
||||
completion_time=time.time() - start_time,
|
||||
llm_calls=1,
|
||||
prompt_tokens=num_tokens(system_prompt, self.token_encoder),
|
||||
prompt_tokens=self.tokenizer.num_tokens(system_prompt),
|
||||
)
|
||||
|
||||
@ -3,19 +3,15 @@
"""Get Tokenizer."""

from typing import TYPE_CHECKING

from graphrag.config.defaults import ENCODING_MODEL
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.tokenizer.litellm_tokenizer import LitellmTokenizer
from graphrag.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
from graphrag.tokenizer.tokenizer import Tokenizer

if TYPE_CHECKING:
    from graphrag.config.models.language_model_config import LanguageModelConfig


def get_tokenizer(
    model_config: "LanguageModelConfig | None" = None,
    model_config: LanguageModelConfig | None = None,
    encoding_model: str = ENCODING_MODEL,
) -> Tokenizer:
    """
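Since `num_tokens_from_string` and the other tiktoken helpers are removed in this commit, token counting now goes through the `Tokenizer` object. A quick sketch of the fallback path, using only calls that appear in this diff (`get_tokenizer()` with no model config, plus `encode`, `decode`, and `num_tokens`):

```python
from graphrag.tokenizer.get_tokenizer import get_tokenizer

# With no model config, get_tokenizer falls back to the default tiktoken
# encoding, mirroring how the updated tests call it.
tokenizer = get_tokenizer()

text = "GraphRAG now threads a Tokenizer through indexing and query code."
tokens = tokenizer.encode(text)
assert tokenizer.decode(tokens) == text
print(tokenizer.num_tokens(text), "tokens")
```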
@ -130,7 +130,7 @@ def get_embedding_store(
    embedding_store = VectorStoreFactory().create_vector_store(
        vector_store_type=vector_store_type,
        vector_store_schema_config=single_embedding_config,
        kwargs={**store},
        **store,
    )
    embedding_store.connect(**store)
    # If there is only a single index, return the embedding store directly
@ -53,7 +53,7 @@ class VectorStoreFactory:
        cls,
        vector_store_type: str,
        vector_store_schema_config: VectorStoreSchemaConfig,
        kwargs: dict,
        **kwargs: dict,
    ) -> BaseVectorStore:
        """Create a vector store object from the provided type.
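Alongside this signature change, the notebooks and tests stop passing a bare `collection_name` and instead provide a `VectorStoreSchemaConfig`. A minimal sketch of the new wiring based on the local search notebook diff (the `./output/lancedb` path is only an assumed location for an existing index):

```python
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.lancedb import LanceDBVectorStore

# Index naming now flows through VectorStoreSchemaConfig rather than a
# bare collection_name argument.
description_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="default-entity-description"
    ),
)

LANCEDB_URI = "./output/lancedb"  # assumed path to the lancedb folder of an index
description_embedding_store.connect(db_uri=LANCEDB_URI)
```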
@ -69,7 +69,7 @@ dependencies = [
    "litellm>=1.77.1",
]

[project.optional-dependencies]
[dependency-groups]
dev = [
    "coverage>=7.6.9",
    "ipykernel>=6.29.5",
@ -239,6 +239,7 @@ ignore = [
    "PERF203", # Needs restructuring of errors, we should bail-out on first error
    "C901", # needs refactoring to remove cyclomatic complexity
    "B008", # Needs to restructure our cli params with Typer into constants
    "ASYNC240"
]

[tool.ruff.lint.per-file-ignores]
@ -81,9 +81,7 @@ def test_register_and_create_custom_vector_store():
|
||||
)
|
||||
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
vector_store_type="custom",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={},
|
||||
vector_store_type="custom", vector_store_schema_config=VectorStoreSchemaConfig()
|
||||
)
|
||||
|
||||
assert custom_vector_store_class.called
|
||||
@ -109,7 +107,6 @@ def test_create_unknown_vector_store():
|
||||
VectorStoreFactory.create_vector_store(
|
||||
vector_store_type="unknown",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={},
|
||||
)
|
||||
|
||||
|
||||
@ -162,7 +159,6 @@ def test_register_class_directly_works():
|
||||
vector_store = VectorStoreFactory.create_vector_store(
|
||||
vector_store_type="custom_class",
|
||||
vector_store_schema_config=VectorStoreSchemaConfig(),
|
||||
kwargs={"collection_name": "test"},
|
||||
)
|
||||
|
||||
assert isinstance(vector_store, CustomVectorStore)
|
||||
|
||||
@ -6,7 +6,7 @@ import platform
|
||||
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
|
||||
sort_context,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
nan = math.nan
|
||||
|
||||
@ -204,16 +204,18 @@ context: list[dict] = [
|
||||
|
||||
|
||||
def test_sort_context():
|
||||
ctx = sort_context(context)
|
||||
tokenizer = get_tokenizer()
|
||||
ctx = sort_context(context, tokenizer=tokenizer)
|
||||
assert ctx is not None, "Context is none"
|
||||
num = num_tokens(ctx)
|
||||
num = tokenizer.num_tokens(ctx)
|
||||
assert num == 828 if platform.system() == "Windows" else 826, (
|
||||
f"num_tokens is not matched for platform (win = 827, else 826): {num}"
|
||||
)
|
||||
|
||||
|
||||
def test_sort_context_max_tokens():
|
||||
ctx = sort_context(context, max_context_tokens=800)
|
||||
tokenizer = get_tokenizer()
|
||||
ctx = sort_context(context, tokenizer=tokenizer, max_context_tokens=800)
|
||||
assert ctx is not None, "Context is none"
|
||||
num = num_tokens(ctx)
|
||||
num = tokenizer.num_tokens(ctx)
|
||||
assert num <= 800, f"num_tokens is not less than or equal to 800: {num}"
|
||||
|
||||
@ -91,7 +91,7 @@ You can host Unified Search datasets locally or in a blob.

# Run the app

Install all the dependencies: `uv sync --extra dev`
Install all the dependencies: `uv sync`

Run the project using streamlit: `uv run poe start`