Compare commits

...

19 Commits

Author SHA1 Message Date
Alonso Guevara
fdb7e3835b
Release v2.7.0 (#2087)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
2025-10-08 21:33:34 -07:00
Nathan Evans
ac8a7f5eef
Housekeeping (#2086)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Add deprecation warnings for fnllm and multi-search

* Fix dangling token_encoder refs

* Fix local_search notebook

* Fix global search dynamic notebook

* Fix global search notebook

* Fix drift notebook

* Switch example notebooks to use LiteLLM config

* Properly annotate dev deps as a group

* Semver

* Remove --extra dev

* Remove llm_model variable

* Ignore ruff ASYNC240

* Add note about expected broken notebook in docs

* Fix custom vector store notebook

* Push tokenizer throughout
2025-10-07 16:21:24 -07:00
Nathan Evans
6c86b0a7bb
Init config cleanup (#2084)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Spruce up init_config output, including LiteLLM default

* Remove deployment_name requirement for Azure

* Semver

* Add model_provider

* Add default model_provider

* Remove OBE test

* Update minimal config for tests

* Add model_provider to verb tests
2025-10-06 12:06:41 -07:00
Nathan Evans
2bd3922d8d
Litellm auth fix (#2083)
* Fix scope for Azure auth with LiteLLM

* Change internal language on max_attempts to max_retries

* Rework model config connectivity validation

* Semver

* Swtich smoke tests to LiteLLM

* Take out temporary retry_strategy = none since it is not fnllm compatible

* Bump smoke test timeout

* Bump smoke timeout further

* Tune smoke params

* Update smoke test bounds

* Remove covariates from min-csv smoke

* Smoke: adjust communities, remove drift

* Remove secrets where they aren't necessary

* Clean out old env var references
2025-10-06 10:54:21 -07:00
Nathan Evans
7f996cf584
Docs/2.6.0 (#2070)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Add basic search to overview

* Add info on input documents DataFrame

* Add info on factories to docs

* Add consumption warning and switch to "christmas" for folder name

* Add logger to factories list

* Add litellm docs. (#2058)

* Fix version for input docs

* Spelling

---------

Co-authored-by: Derek Worthen <worthend.derek@gmail.com>
2025-09-23 14:48:28 -07:00
Alonso Guevara
9bc899fe95
Release v2.6.0 (#2068)
Some checks are pending
gh-pages / build (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Publish (pypi) / Upload release to PyPI (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Spellcheck / spellcheck (push) Waiting to run
2025-09-22 16:16:54 -06:00
Derek Worthen
2b70e4a4f3
Tokenizer (#2051)
* Add LiteLLM chat and embedding model providers.

* Fix code review findings.

* Add litellm.

* Fix formatting.

* Update dictionary.

* Update litellm.

* Fix embedding.

* Remove manual use of tiktoken and replace with
Tokenizer interface. Adds support for encoding
and decoding the models supported by litellm.

* Update litellm.

* Configure litellm to drop unsupported params.

* Cleanup semversioner release notes.

* Add num_tokens util to Tokenizer interface.

* Update litellm service factories.

* Cleanup litellm chat/embedding model argument assignment.

* Update chat and embedding type field for litellm use and future migration away from fnllm.

* Flatten litellm service organization.

* Update litellm.

* Update litellm factory validation.

* Flatten litellm rate limit service organization.

* Update rate limiter - disable with None/null instead of 0.

* Fix usage of get_tokenizer.

* Update litellm service registrations.

* Add jitter to exponential retry.

* Update validation.

* Update validation.

* Add litellm request logging layer.

* Update cache key.

* Update defaults.

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2025-09-22 13:55:14 -06:00
gaudyb
82cd3b7df2
Custom vector store schema implementation (#2062)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* progress on vector customization

* fix for lancedb vectors

* cosmosdb implementation

* uv run poe format

* clean test for vector store

* semversioner update

* test_factory.py integration test fixes

* fixes for cosmosdb test

* integration test fix for lancedb

* uv fix for format

* test fixes

* fixes for tests

* fix cosmosdb bug

* print statement

* test

* test

* fix cosmosdb bug

* test validation

* validation cosmosdb

* validate cosmosdb

* fix cosmosdb

* fix small feedback from PR

---------

Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
2025-09-19 10:11:34 -07:00
Nathan Evans
075cadd59a
Clarify managed auth setup in Azure documentation (#2064)
Some checks are pending
gh-pages / build (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python CI / python-ci (ubuntu-latest, 3.11) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.10) (push) Waiting to run
Python CI / python-ci (windows-latest, 3.11) (push) Waiting to run
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Python Publish (pypi) / Upload release to PyPI (push) Waiting to run
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Waiting to run
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Waiting to run
Spellcheck / spellcheck (push) Waiting to run
Updated instructions for using managed auth on Azure.
2025-09-18 14:58:09 -07:00
Nathan Evans
6d7a50b7f0
Remove community reports rate limiter (#2056)
* Remove hard-coded community reports rate limiter

* Semver

* Format

* Add memory cache factory
2025-09-18 13:40:24 -07:00
Nathan Evans
2bf7e7c018
Fix multi-index search (#2063) 2025-09-18 12:49:56 -07:00
Nathan Evans
6c66b7c30f
Configure async for NLP extraction (#2059)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Make async mode configurable for NLP extraction

* Semver
2025-09-16 11:52:18 -07:00
Chenghua Duan
a398cc38bb
Update command to use no-discover-entity-types (#2038)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
"no-entity-types" is an incorrect configuration parameter.

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2025-09-09 16:46:06 -06:00
Derek Worthen
ac95c917d3
Update fnllm. (#2043)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
2025-09-05 08:52:05 -07:00
Nathan Evans
1cb20b66f5
Input docs API parameter (#2034)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Add optional input_documents to index API

* Semver

* Add input dataframe example notebook

* Format

* Fix docs and notebook
2025-09-02 16:15:50 -07:00
Copilot
2030f94eb4
Refactor CacheFactory, StorageFactory, and VectorStoreFactory to use consistent registration patterns and add custom vector store documentation (#2006)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Initial plan

* Refactor VectorStoreFactory to use registration functionality like StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Fix linting issues in VectorStoreFactory refactoring

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Remove backward compatibility support from VectorStoreFactory and StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* Run ruff check --fix and ruff format, add semversioner file

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff formatting fixes

* Fix pytest errors in storage factory tests by updating PipelineStorage interface implementation

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* ruff formatting fixes

* update storage factory design

* Refactor CacheFactory to use registration functionality like StorageFactory

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* revert copilot changes

* fix copilot changes

* update comments

* Fix failing pytest compatibility for factory tests

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update class instantiation issue

* ruff fixes

* fix pytest

* add default value

* ruff formatting changes

* ruff fixes

* revert minor changes

* cleanup cache factory

* Update CacheFactory tests to match consistent factory pattern

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* update pytest thresholds

* adjust threshold levels

* Add custom vector store implementation notebook

Create comprehensive notebook demonstrating how to implement and register custom vector stores with GraphRAG as a plug-and-play framework. Includes:

- Complete implementation of SimpleInMemoryVectorStore
- Registration with VectorStoreFactory
- Testing and validation examples
- Configuration examples for GraphRAG settings
- Advanced features and best practices
- Production considerations checklist

The notebook provides a complete walkthrough for developers to understand and implement their own vector store backends.

Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>

* remove sample notebook for now

* update tests

* fix cache pytests

* add pandas-stub to dev dependencies

* disable warning check for well known key

* skip tests when running on ubuntu

* add documentation for custom vector store implementations

* ignore ruff findings in notebooks

* fix merge breakages

* speedup CLI import statements

* remove unnecessary import statements in init file

* Add str type option on storage/cache type

* Fix store name

* Add LoggerFactory

* Fix up logging setup across CLI/API

* Add LoggerFactory test

* Fix err message

* Semver

* Remove enums from factory methods

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: jgbradley1 <654554+jgbradley1@users.noreply.github.com>
Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
Co-authored-by: Nathan Evans <github@talkswithnumbers.com>
2025-08-28 13:53:07 -07:00
Nathan Evans
69ad36e735
Fix id baseline (#2036)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Fix all human_readable_id columns to start at 0

* Semver
2025-08-27 11:15:21 -07:00
Nathan Evans
30bdb35cc8
Selective embeddings loading (#2035)
* Invert embedding table loading logic

* Semver
2025-08-27 11:12:01 -07:00
Nathan Evans
77fb7d9d7d
Logging improvements (#2030)
Some checks failed
gh-pages / build (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (ubuntu-latest, 3.11) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python CI / python-ci (windows-latest, 3.11) (push) Has been cancelled
Python Integration Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Integration Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Notebook Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Python Publish (pypi) / Upload release to PyPI (push) Has been cancelled
Python Smoke Tests / python-ci (ubuntu-latest, 3.10) (push) Has been cancelled
Python Smoke Tests / python-ci (windows-latest, 3.10) (push) Has been cancelled
Spellcheck / spellcheck (push) Has been cancelled
* Turn down blob/cosmos exception reporting to match file storage

* Restore indexing-engine.log

* Restore some basic console logging and progress for index CLI

* Semver

* Ignore small ruff complaints

* Fix CLI console printing
2025-08-25 14:56:43 -07:00
195 changed files with 8645 additions and 3555 deletions

View File

@ -15,8 +15,6 @@ jobs:
GH_PAGES: 1
DEBUG: 1
GRAPHRAG_API_KEY: ${{ secrets.GRAPHRAG_API_KEY }}
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
steps:
- uses: actions/checkout@v4
@ -33,7 +31,7 @@ jobs:
- name: Install dependencies
shell: bash
run: uv sync --extra dev
run: uv sync
- name: mkdocs build
shell: bash

View File

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Check

View File

@ -67,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Build

View File

@ -38,8 +38,6 @@ jobs:
env:
DEBUG: 1
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_NOTEBOOK_KEY }}
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
runs-on: ${{ matrix.os }}
steps:
@ -69,7 +67,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Notebook Test

View File

@ -37,20 +37,8 @@ jobs:
fail-fast: false # Continue running all jobs even if one fails
env:
DEBUG: 1
GRAPHRAG_LLM_TYPE: "azure_openai_chat"
GRAPHRAG_EMBEDDING_TYPE: "azure_openai_embedding"
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHRAG_API_BASE: ${{ secrets.GRAPHRAG_API_BASE }}
GRAPHRAG_API_VERSION: ${{ secrets.GRAPHRAG_API_VERSION }}
GRAPHRAG_LLM_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_LLM_DEPLOYMENT_NAME }}
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME }}
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
# We have Windows + Linux runners in 3.10, so we need to divide the rate limits by 2
GRAPHRAG_LLM_TPM: 200_000 # 400_000 / 2
GRAPHRAG_LLM_RPM: 1_000 # 2_000 / 2
GRAPHRAG_EMBEDDING_TPM: 225_000 # 450_000 / 2
GRAPHRAG_EMBEDDING_RPM: 1_000 # 2_000 / 2
# Azure AI Search config
AZURE_AI_SEARCH_URL_ENDPOINT: ${{ secrets.AZURE_AI_SEARCH_URL_ENDPOINT }}
AZURE_AI_SEARCH_API_KEY: ${{ secrets.AZURE_AI_SEARCH_API_KEY }}
@ -84,7 +72,7 @@ jobs:
- name: Install dependencies
shell: bash
run: |
uv sync --extra dev
uv sync
uv pip install gensim
- name: Build

54
.semversioner/2.6.0.json Normal file
View File

@ -0,0 +1,54 @@
{
"changes": [
{
"description": "Add LiteLLM chat and embedding model providers.",
"type": "minor"
},
{
"description": "Add LoggerFactory and clean up related API.",
"type": "minor"
},
{
"description": "Add config for NLP async mode.",
"type": "minor"
},
{
"description": "Add optional input documents to indexing API.",
"type": "minor"
},
{
"description": "add customization to vector store",
"type": "minor"
},
{
"description": "Add gpt-5 support by updating fnllm dependency.",
"type": "patch"
},
{
"description": "Fix all human_readable_id fields to be 0-based.",
"type": "patch"
},
{
"description": "Fix multi-index search.",
"type": "patch"
},
{
"description": "Improve upon recent logging refactor",
"type": "patch"
},
{
"description": "Make cache, storage, and vector_store factories consistent with similar registration support",
"type": "patch"
},
{
"description": "Remove hard-coded community rate limiter.",
"type": "patch"
},
{
"description": "generate_text_embeddings only loads tables if embedding field is specified.",
"type": "patch"
}
],
"created_at": "2025-09-22T21:44:51+00:00",
"version": "2.6.0"
}

18
.semversioner/2.7.0.json Normal file
View File

@ -0,0 +1,18 @@
{
"changes": [
{
"description": "Set LiteLLM as default in init_content.",
"type": "minor"
},
{
"description": "Fix Azure auth scope issue with LiteLLM.",
"type": "patch"
},
{
"description": "Housekeeping toward 2.7.",
"type": "patch"
}
],
"created_at": "2025-10-08T22:39:42+00:00",
"version": "2.7.0"
}

55
.vscode/launch.json vendored
View File

@ -6,21 +6,24 @@
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
"index",
"--root",
"<path_to_index_folder>"
],
"console": "integratedTerminal"
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "uv",
"module": "graphrag",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"query",
"--root",
"<path_to_index_folder>",
"--method", "basic",
"--query", "What are the top themes in this story",
]
},
@ -34,6 +37,42 @@
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
},
{
"name": "Debug Integration Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/integration/vector_stores",
"-k", "test_azure_ai_search"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Verbs Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/verbs",
"-k", "test_generate_text_embeddings"
],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Debug Smoke Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"./tests/smoke",
"-k", "test_fixtures"
],
"console": "integratedTerminal",
"justMyCode": false
},
]
}

View File

@ -1,6 +1,27 @@
# Changelog
Note: version releases in the 0.x.y range may introduce breaking changes.
## 2.7.0
- minor: Set LiteLLM as default in init_content.
- patch: Fix Azure auth scope issue with LiteLLM.
- patch: Housekeeping toward 2.7.
## 2.6.0
- minor: Add LiteLLM chat and embedding model providers.
- minor: Add LoggerFactory and clean up related API.
- minor: Add config for NLP async mode.
- minor: Add optional input documents to indexing API.
- minor: add customization to vector store
- patch: Add gpt-5 support by updating fnllm dependency.
- patch: Fix all human_readable_id fields to be 0-based.
- patch: Fix multi-index search.
- patch: Improve upon recent logging refactor
- patch: Make cache, storage, and vector_store factories consistent with similar registration support
- patch: Remove hard-coded community rate limiter.
- patch: generate_text_embeddings only loads tables if embedding field is specified.
## 2.5.0
- minor: Add additional context variable to build index signature for custom parameter bag

View File

@ -11,12 +11,8 @@
## Install Dependencies
```shell
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate
# install python dependencies
uv sync --extra dev
uv sync
```
## Execute the indexing engine
@ -119,8 +115,3 @@ and then in your bashrc, add
Make sure you have python3.10-dev installed or more generally `python<version>-dev`
`sudo apt-get install python3.10-dev`
### LLM call constantly exceeds TPM, RPM or time limits
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
to reduce concurrency. Please refer to the [Configuration Documents](https://microsoft.github.io/graphrag/config/overview/)

View File

@ -81,6 +81,7 @@ typer
spacy
kwargs
ollama
litellm
# Library Methods
iterrows
@ -103,6 +104,8 @@ isin
nocache
nbconvert
levelno
acompletion
aembedding
# HTML
nbsp

View File

@ -8,9 +8,38 @@ GraphRAG was built and tested using OpenAI models, so this is the default model
GraphRAG also utilizes a language model wrapper library used by several projects within our team, called fnllm. fnllm provides two important functions for GraphRAG: rate limiting configuration to help us maximize throughput for large indexing jobs, and robust caching of API calls to minimize consumption on repeated indexes for testing, experimentation, or incremental ingest. fnllm uses the OpenAI Python SDK under the covers, so OpenAI-compliant endpoints are a base requirement out-of-the-box.
Starting with version 2.6.0, GraphRAG supports using [LiteLLM](https://docs.litellm.ai/) instead of fnllm for calling language models. LiteLLM provides support for 100+ models though it is important to note that when choosing a model it must support returning [structured outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/) adhering to a [JSON schema](https://docs.litellm.ai/docs/completion/json_mode).
Example using LiteLLm as the language model tool for GraphRAG:
```yaml
models:
default_chat_model:
type: chat
auth_type: api_key
api_key: ${GEMINI_API_KEY}
model_provider: gemini
model: gemini-2.5-flash-lite
default_embedding_model:
type: embedding
auth_type: api_key
api_key: ${GEMINI_API_KEY}
model_provider: gemini
model: gemini-embedding-001
```
To use LiteLLM one must
- Set `type` to either `chat` or `embedding`.
- Provide a `model_provider`, e.g., `openai`, `azure`, `gemini`, etc.
- Set the `model` to a one supported by the `model_provider`'s API.
- Provide a `deployment_name` if using `azure` as the `model_provider`.
See [Detailed Configuration](yaml.md) for more details on configuration. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`).
## Model Selection Considerations
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo.
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo. As stated above, non-OpenAI models are now supported with GraphRAG 2.6.0 and onwards through the use of LiteLLM but the suite of gpt-4 series of models from OpenAI remain the most tested and supported suite of models for GraphRAG.
Versions of GraphRAG before 2.2.0 made extensive use of `max_tokens` and `logit_bias` to control generated response length or content. The introduction of the o-series of models added new, non-compatible parameters because these models include a reasoning component that has different consumption patterns and response generation attributes than non-reasoning models. GraphRAG 2.2.0 now supports these models, but there are important differences that need to be understood before you switch.
@ -58,11 +87,11 @@ Another option would be to avoid using a language model at all for the graph ext
## Using Non-OpenAI Models
As noted above, our primary experience and focus has been on OpenAI models, so this is what is supported out-of-the-box. Many users have requested support for additional model types, but it's out of the scope of our research to handle the many models available today. There are two approaches you can use to connect to a non-OpenAI model:
As shown above, non-OpenAI models may be used via LiteLLM starting with GraphRAG version 2.6.0 but cases may still exist in which some users wish to use models not supported by LiteLLM. There are two approaches one can use to connect to unsupported models:
### Proxy APIs
Many users have used platforms such as [ollama](https://ollama.com/) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLLM Proxy Server](https://docs.litellm.ai/docs/simple_proxy) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
### Model Protocol

View File

@ -41,7 +41,8 @@ models:
- `api_key` **str** - The OpenAI API key to use.
- `auth_type` **api_key|azure_managed_identity** - Indicate how you want to authenticate requests.
- `type` **openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
- `type` **chat**|**embedding**|**openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
- `model_provider` **str|None** - The model provider to use, e.g., openai, azure, anthropic, etc. Required when `type == chat|embedding`. When `type == chat|embedding`, [LiteLLM](https://docs.litellm.ai/) is used under the hood which has support for calling 100+ models. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`). [View Language Model Selection](models.md) for more details and examples on using LiteLLM.
- `model` **str** - The model name.
- `encoding_model` **str** - The text encoding model to use. Default is to use the encoding model aligned with the language model (i.e., it is retrieved from tiktoken if unset).
- `api_base` **str** - The API base url to use.

View File

@ -12,12 +12,8 @@
## Install Dependencies
```sh
# (optional) create virtual environment
uv venv --python 3.10
source .venv/bin/activate
# install python dependencies
uv sync --extra dev
uv sync
```
## Execute the Indexing Engine
@ -77,8 +73,3 @@ Make sure llvm-9 and llvm-9-dev are installed:
and then in your bashrc, add
`export LLVM_CONFIG=/usr/bin/llvm-config-9`
### LLM call constantly exceeds TPM, RPM or time limits
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
to reduce concurrency. Please refer to the [Configuration Documents](config/overview.md)

View File

@ -67,6 +67,8 @@
"metadata": {},
"outputs": [],
"source": [
"# note that we expect this to fail on the deployed docs because the PROJECT_DIRECTORY is not set to a real location.\n",
"# if you run this notebook locally, make sure to point at a location containing your settings.yaml\n",
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
]
},

View File

@ -0,0 +1,680 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bring-Your-Own Vector Store\n",
"\n",
"This notebook demonstrates how to implement a custom vector store and register for usage with GraphRAG.\n",
"\n",
"## Overview\n",
"\n",
"GraphRAG uses a plug-and-play architecture that allow for easy integration of custom vector stores (outside of what is natively supported) by following a factory design pattern. This allows you to:\n",
"\n",
"- **Extend functionality**: Add support for new vector database backends\n",
"- **Customize behavior**: Implement specialized search logic or data structures\n",
"- **Integrate existing systems**: Connect GraphRAG to your existing vector database infrastructure\n",
"\n",
"### What You'll Learn\n",
"\n",
"1. Understanding the `BaseVectorStore` interface\n",
"2. Implementing a custom vector store class\n",
"3. Registering your vector store with the `VectorStoreFactory`\n",
"4. Testing and validating your implementation\n",
"5. Configuring GraphRAG to use your custom vector store\n",
"\n",
"Let's get started!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Import Required Dependencies\n",
"\n",
"First, let's import the necessary GraphRAG components and other dependencies we'll need.\n",
"\n",
"```bash\n",
"pip install graphrag\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"\n",
"import numpy as np\n",
"import yaml\n",
"\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.data_model.types import TextEmbedder\n",
"\n",
"# GraphRAG vector store components\n",
"from graphrag.vector_stores.base import (\n",
" BaseVectorStore,\n",
" VectorStoreDocument,\n",
" VectorStoreSearchResult,\n",
")\n",
"from graphrag.vector_stores.factory import VectorStoreFactory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Understand the BaseVectorStore Interface\n",
"\n",
"Before using a custom vector store, let's examine the `BaseVectorStore` interface to understand what methods need to be implemented."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's inspect the BaseVectorStore class to understand the required methods\n",
"import inspect\n",
"\n",
"print(\"BaseVectorStore Abstract Methods:\")\n",
"print(\"=\" * 40)\n",
"\n",
"abstract_methods = []\n",
"for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n",
" if getattr(method, \"__isabstractmethod__\", False):\n",
" signature = inspect.signature(method)\n",
" abstract_methods.append(f\"• {name}{signature}\")\n",
" print(f\"• {name}{signature}\")\n",
"\n",
"print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Implement a Custom Vector Store\n",
"\n",
"Now let's implement a simple in-memory vector store as an example. This vector store will:\n",
"\n",
"- Store documents and vectors in memory using Python data structures\n",
"- Support all required BaseVectorStore methods\n",
"\n",
"**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleInMemoryVectorStore(BaseVectorStore):\n",
" \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n",
"\n",
" This vector store stores documents and their embeddings in memory and provides\n",
" basic similarity search functionality using cosine similarity.\n",
"\n",
" WARNING: This is for demonstration only - not suitable for production use.\n",
" For production, consider using optimized vector databases like LanceDB,\n",
" Azure AI Search, or other specialized vector stores.\n",
" \"\"\"\n",
"\n",
" # Internal storage for documents and vectors\n",
" documents: dict[str, VectorStoreDocument]\n",
" vectors: dict[str, np.ndarray]\n",
" connected: bool\n",
"\n",
" def __init__(self, **kwargs: Any):\n",
" \"\"\"Initialize the in-memory vector store.\"\"\"\n",
" super().__init__(**kwargs)\n",
"\n",
" self.documents: dict[str, VectorStoreDocument] = {}\n",
" self.vectors: dict[str, np.ndarray] = {}\n",
" self.connected = False\n",
"\n",
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
"\n",
" def connect(self, **kwargs: Any) -> None:\n",
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
" self.connected = True\n",
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
"\n",
" def load_documents(\n",
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
" ) -> None:\n",
" \"\"\"Load documents into the vector store.\"\"\"\n",
" if not self.connected:\n",
" msg = \"Vector store not connected. Call connect() first.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" if overwrite:\n",
" self.documents.clear()\n",
" self.vectors.clear()\n",
"\n",
" loaded_count = 0\n",
" for doc in documents:\n",
" if doc.vector is not None:\n",
" doc_id = str(doc.id)\n",
" self.documents[doc_id] = doc\n",
" self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n",
" loaded_count += 1\n",
"\n",
" print(f\"📚 Loaded {loaded_count} documents into vector store\")\n",
"\n",
" def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
" \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n",
" # Normalize vectors\n",
" norm1 = np.linalg.norm(vec1)\n",
" norm2 = np.linalg.norm(vec2)\n",
"\n",
" if norm1 == 0 or norm2 == 0:\n",
" return 0.0\n",
"\n",
" return float(np.dot(vec1, vec2) / (norm1 * norm2))\n",
"\n",
" def similarity_search_by_vector(\n",
" self, query_embedding: list[float], k: int = 10, **kwargs: Any\n",
" ) -> list[VectorStoreSearchResult]:\n",
" \"\"\"Perform similarity search using a query vector.\"\"\"\n",
" if not self.connected:\n",
" msg = \"Vector store not connected. Call connect() first.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" if not self.vectors:\n",
" return []\n",
"\n",
" query_vec = np.array(query_embedding, dtype=np.float32)\n",
" similarities = []\n",
"\n",
" # Calculate similarity with all stored vectors\n",
" for doc_id, stored_vec in self.vectors.items():\n",
" similarity = self._cosine_similarity(query_vec, stored_vec)\n",
" similarities.append((doc_id, similarity))\n",
"\n",
" # Sort by similarity (descending) and take top k\n",
" similarities.sort(key=lambda x: x[1], reverse=True)\n",
" top_k = similarities[:k]\n",
"\n",
" # Create search results\n",
" results = []\n",
" for doc_id, score in top_k:\n",
" document = self.documents[doc_id]\n",
" result = VectorStoreSearchResult(document=document, score=score)\n",
" results.append(result)\n",
"\n",
" return results\n",
"\n",
" def similarity_search_by_text(\n",
" self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n",
" ) -> list[VectorStoreSearchResult]:\n",
" \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n",
" # Embed the text first\n",
" query_embedding = text_embedder(text)\n",
"\n",
" # Use vector search with the embedding\n",
" return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
"\n",
" def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
" \"\"\"Build a query filter to filter documents by id.\n",
"\n",
" For this simple implementation, we return the list of IDs as the filter.\n",
" \"\"\"\n",
" return [str(id_) for id_ in include_ids]\n",
"\n",
" def search_by_id(self, id: str) -> VectorStoreDocument:\n",
" \"\"\"Search for a document by id.\"\"\"\n",
" doc_id = str(id)\n",
" if doc_id not in self.documents:\n",
" msg = f\"Document with id '{id}' not found\"\n",
" raise KeyError(msg)\n",
"\n",
" return self.documents[doc_id]\n",
"\n",
" def get_stats(self) -> dict[str, Any]:\n",
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
" return {\n",
" \"index_name\": self.index_name,\n",
" \"document_count\": len(self.documents),\n",
" \"vector_count\": len(self.vectors),\n",
" \"connected\": self.connected,\n",
" \"vector_dimension\": len(next(iter(self.vectors.values())))\n",
" if self.vectors\n",
" else 0,\n",
" }\n",
"\n",
"\n",
"print(\"✅ SimpleInMemoryVectorStore class defined!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Register the Custom Vector Store\n",
"\n",
"Now let's register our custom vector store with the `VectorStoreFactory` so it can be used throughout GraphRAG."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Register our custom vector store with a unique identifier\n",
"CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
"\n",
"# Register the vector store class\n",
"VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
"\n",
"print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
"\n",
"# Verify registration\n",
"available_types = VectorStoreFactory.get_vector_store_types()\n",
"print(f\"\\n📋 Available vector store types: {available_types}\")\n",
"print(\n",
" f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Test the Custom Vector Store\n",
"\n",
"Let's create some sample data and test our custom vector store implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create sample documents with mock embeddings\n",
"def create_mock_embedding(dimension: int = 384) -> list[float]:\n",
" \"\"\"Create a random embedding vector for testing.\"\"\"\n",
" return np.random.normal(0, 1, dimension).tolist()\n",
"\n",
"\n",
"# Sample documents\n",
"sample_documents = [\n",
" VectorStoreDocument(\n",
" id=\"doc_1\",\n",
" text=\"GraphRAG is a powerful knowledge graph extraction and reasoning framework.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"technology\", \"source\": \"documentation\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_2\",\n",
" text=\"Vector stores enable efficient similarity search over high-dimensional data.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"technology\", \"source\": \"research\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_3\",\n",
" text=\"Machine learning models can process and understand natural language text.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"AI\", \"source\": \"article\"},\n",
" ),\n",
" VectorStoreDocument(\n",
" id=\"doc_4\",\n",
" text=\"Custom implementations allow for specialized behavior and integration.\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"category\": \"development\", \"source\": \"tutorial\"},\n",
" ),\n",
"]\n",
"\n",
"print(f\"📝 Created {len(sample_documents)} sample documents\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test creating vector store using the factory\n",
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
"\n",
"# Create vector store instance using factory\n",
"vector_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
")\n",
"\n",
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
"print(f\"📊 Initial stats: {vector_store.get_stats()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Connect and load documents\n",
"vector_store.connect()\n",
"vector_store.load_documents(sample_documents)\n",
"\n",
"print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test similarity search\n",
"query_vector = create_mock_embedding() # Random query vector for testing\n",
"\n",
"search_results = vector_store.similarity_search_by_vector(\n",
" query_vector,\n",
" k=3, # Get top 3 similar documents\n",
")\n",
"\n",
"print(f\"🔍 Found {len(search_results)} similar documents:\\n\")\n",
"\n",
"for i, result in enumerate(search_results, 1):\n",
" doc = result.document\n",
" print(f\"{i}. ID: {doc.id}\")\n",
" print(f\" Text: {doc.text[:60]}...\")\n",
" print(f\" Similarity Score: {result.score:.4f}\")\n",
" print(f\" Category: {doc.attributes.get('category', 'N/A')}\")\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test search by ID\n",
"try:\n",
" found_doc = vector_store.search_by_id(\"doc_2\")\n",
" print(\"✅ Found document by ID:\")\n",
" print(f\" ID: {found_doc.id}\")\n",
" print(f\" Text: {found_doc.text}\")\n",
" print(f\" Attributes: {found_doc.attributes}\")\n",
"except KeyError as e:\n",
" print(f\"❌ Error: {e}\")\n",
"\n",
"# Test filter by ID\n",
"id_filter = vector_store.filter_by_id([\"doc_1\", \"doc_3\"])\n",
"print(f\"\\n🔧 ID filter result: {id_filter}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Configuration for GraphRAG\n",
"\n",
"Now let's see how you would configure GraphRAG to use your custom vector store in a settings file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example GraphRAG yaml settings\n",
"example_settings = {\n",
" \"vector_store\": {\n",
" \"default_vector_store\": {\n",
" \"type\": CUSTOM_VECTOR_STORE_TYPE, # \"simple_memory\"\n",
" \"collection_name\": \"graphrag_entities\",\n",
" # Add any custom parameters your vector store needs\n",
" \"custom_parameter\": \"custom_value\",\n",
" }\n",
" },\n",
" # Other GraphRAG configuration...\n",
" \"models\": {\n",
" \"default_embedding_model\": {\n",
" \"type\": \"openai_embedding\",\n",
" \"model\": \"text-embedding-3-small\",\n",
" }\n",
" },\n",
"}\n",
"\n",
"# Convert to YAML format for settings.yml\n",
"yaml_config = yaml.dump(example_settings, default_flow_style=False, indent=2)\n",
"\n",
"print(\"📄 Example settings.yml configuration:\")\n",
"print(\"=\" * 40)\n",
"print(yaml_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 7: Integration with GraphRAG Pipeline\n",
"\n",
"Here's how your custom vector store would be used in a typical GraphRAG pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of how GraphRAG would use your custom vector store\n",
"def simulate_graphrag_pipeline():\n",
" \"\"\"Simulate how GraphRAG would use the custom vector store.\"\"\"\n",
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
"\n",
" # 1. GraphRAG creates vector store using factory\n",
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
"\n",
" store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=schema,\n",
" similarity_threshold=0.3,\n",
" )\n",
" store.connect()\n",
"\n",
" print(\"✅ Step 1: Vector store created and connected\")\n",
"\n",
" # 2. During indexing, GraphRAG loads extracted entities\n",
" entity_documents = [\n",
" VectorStoreDocument(\n",
" id=f\"entity_{i}\",\n",
" text=f\"Entity {i} description: Important concept in the knowledge graph\",\n",
" vector=create_mock_embedding(),\n",
" attributes={\"type\": \"entity\", \"importance\": i % 3 + 1},\n",
" )\n",
" for i in range(10)\n",
" ]\n",
"\n",
" store.load_documents(entity_documents)\n",
" print(f\"✅ Step 2: Loaded {len(entity_documents)} entity documents\")\n",
"\n",
" # 3. During query time, GraphRAG searches for relevant entities\n",
" query_embedding = create_mock_embedding()\n",
" relevant_entities = store.similarity_search_by_vector(query_embedding, k=5)\n",
"\n",
" print(f\"✅ Step 3: Found {len(relevant_entities)} relevant entities for query\")\n",
"\n",
" # 4. GraphRAG uses these entities for context building\n",
" context_entities = [result.document for result in relevant_entities]\n",
"\n",
" print(\"✅ Step 4: Context built using retrieved entities\")\n",
" print(f\"📊 Final stats: {store.get_stats()}\")\n",
"\n",
" return context_entities\n",
"\n",
"\n",
"# Run the simulation\n",
"context = simulate_graphrag_pipeline()\n",
"print(f\"\\n🎯 Retrieved {len(context)} entities for context building\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 8: Testing and Validation\n",
"\n",
"Let's create a comprehensive test suite to ensure our vector store works correctly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_custom_vector_store():\n",
" \"\"\"Comprehensive test suite for the custom vector store.\"\"\"\n",
" print(\"🧪 Running comprehensive vector store tests...\\n\")\n",
"\n",
" # Test 1: Basic functionality\n",
" print(\"Test 1: Basic functionality\")\n",
" store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
" )\n",
" store.connect()\n",
"\n",
" # Load test documents\n",
" test_docs = sample_documents[:2]\n",
" store.load_documents(test_docs)\n",
"\n",
" assert len(store.documents) == 2, \"Should have 2 documents\"\n",
" assert len(store.vectors) == 2, \"Should have 2 vectors\"\n",
" print(\"✅ Basic functionality test passed\")\n",
"\n",
" # Test 2: Search functionality\n",
" print(\"\\nTest 2: Search functionality\")\n",
" query_vec = create_mock_embedding()\n",
" results = store.similarity_search_by_vector(query_vec, k=5)\n",
"\n",
" assert len(results) <= 2, \"Should not return more results than documents\"\n",
" assert all(isinstance(r, VectorStoreSearchResult) for r in results), (\n",
" \"Should return VectorStoreSearchResult objects\"\n",
" )\n",
" assert all(-1 <= r.score <= 1 for r in results), (\n",
" \"Similarity scores should be between -1 and 1\"\n",
" )\n",
" print(\"✅ Search functionality test passed\")\n",
"\n",
" # Test 3: Search by ID\n",
" print(\"\\nTest 3: Search by ID\")\n",
" found_doc = store.search_by_id(\"doc_1\")\n",
" assert found_doc.id == \"doc_1\", \"Should find correct document\"\n",
"\n",
" try:\n",
" store.search_by_id(\"nonexistent\")\n",
" assert False, \"Should raise KeyError for nonexistent ID\"\n",
" except KeyError:\n",
" pass # Expected\n",
"\n",
" print(\"✅ Search by ID test passed\")\n",
"\n",
" # Test 4: Filter functionality\n",
" print(\"\\nTest 4: Filter functionality\")\n",
" filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
" assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
" print(\"✅ Filter functionality test passed\")\n",
"\n",
" # Test 5: Error handling\n",
" print(\"\\nTest 5: Error handling\")\n",
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
" )\n",
"\n",
" try:\n",
" disconnected_store.load_documents(test_docs)\n",
" assert False, \"Should raise error when not connected\"\n",
" except RuntimeError:\n",
" pass # Expected\n",
"\n",
" try:\n",
" disconnected_store.similarity_search_by_vector(query_vec)\n",
" assert False, \"Should raise error when not connected\"\n",
" except RuntimeError:\n",
" pass # Expected\n",
"\n",
" print(\"✅ Error handling test passed\")\n",
"\n",
" print(\"\\n🎉 All tests passed! Your custom vector store is working correctly.\")\n",
"\n",
"\n",
"# Run the tests\n",
"test_custom_vector_store()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary and Next Steps\n",
"\n",
"Congratulations! You've successfully learned how to implement and register a custom vector store with GraphRAG. Here's what you accomplished:\n",
"\n",
"### What You Built\n",
"- ✅ **Custom Vector Store Class**: Implemented `SimpleInMemoryVectorStore` with all required methods\n",
"- ✅ **Factory Integration**: Registered your vector store with `VectorStoreFactory`\n",
"- ✅ **Comprehensive Testing**: Validated functionality with a full test suite\n",
"- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n",
"\n",
"### Key Takeaways\n",
"1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n",
"2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n",
"3. **Configuration**: Vector stores are configured in GraphRAG settings files\n",
"4. **Testing**: Thoroughly test all functionality before deploying\n",
"\n",
"### Next Steps\n",
"Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n",
"\n",
"### Resources\n",
"- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n",
"\n",
"Happy building! 🚀"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -20,11 +20,11 @@
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.language_model.manager import ModelManager\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_entities,\n",
@ -37,6 +37,7 @@
" DRIFTSearchContextBuilder,\n",
")\n",
"from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
"\n",
"INPUT_DIR = \"./inputs/operation dulce\"\n",
@ -62,12 +63,16 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-entity-description\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-entity-description\"\n",
" ),\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
"full_content_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-community-full_content\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-community-full_content\"\n",
" )\n",
")\n",
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
@ -94,33 +99,33 @@
"outputs": [],
"source": [
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
"\n",
"chat_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"chat_model = ModelManager().get_or_create_chat_model(\n",
" name=\"local_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=chat_config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
"tokenizer = get_tokenizer(chat_config)\n",
"\n",
"embedding_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIEmbedding,\n",
" model=embedding_model,\n",
" type=ModelType.Embedding,\n",
" model_provider=\"openai\",\n",
" model=\"text-embedding-3-small\",\n",
" max_retries=20,\n",
")\n",
"\n",
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
" name=\"local_search_embedding\",\n",
" model_type=ModelType.OpenAIEmbedding,\n",
" model_type=ModelType.Embedding,\n",
" config=embedding_config,\n",
")"
]
@ -173,12 +178,12 @@
" reports=reports,\n",
" entity_text_embeddings=description_embedding_store,\n",
" text_units=text_units,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" config=drift_params,\n",
")\n",
"\n",
"search = DRIFTSearch(\n",
" model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
" model=chat_model, context_builder=context_builder, tokenizer=tokenizer\n",
")"
]
},
@ -212,7 +217,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -226,7 +231,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -19,7 +19,6 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
@ -32,7 +31,8 @@
"from graphrag.query.structured_search.global_search.community_context import (\n",
" GlobalCommunityContext,\n",
")\n",
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
"from graphrag.query.structured_search.global_search.search import GlobalSearch\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer"
]
},
{
@ -58,21 +58,21 @@
"outputs": [],
"source": [
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"\n",
"config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"model = ModelManager().get_or_create_chat_model(\n",
" name=\"global_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)"
"tokenizer = get_tokenizer(config)"
]
},
{
@ -142,7 +142,7 @@
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
")"
]
},
@ -193,7 +193,7 @@
"search_engine = GlobalSearch(\n",
" model=model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" map_llm_params=map_llm_params,\n",
" reduce_llm_params=reduce_llm_params,\n",
@ -241,7 +241,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -255,7 +255,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -19,7 +19,6 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
@ -57,22 +56,24 @@
"metadata": {},
"outputs": [],
"source": [
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"\n",
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"\n",
"config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"model = ModelManager().get_or_create_chat_model(\n",
" name=\"global_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)"
"tokenizer = get_tokenizer(config)"
]
},
{
@ -155,11 +156,11 @@
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" dynamic_community_selection=True,\n",
" dynamic_community_selection_kwargs={\n",
" \"model\": model,\n",
" \"token_encoder\": token_encoder,\n",
" \"tokenizer\": tokenizer,\n",
" },\n",
")"
]
@ -211,7 +212,7 @@
"search_engine = GlobalSearch(\n",
" model=model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" map_llm_params=map_llm_params,\n",
" reduce_llm_params=reduce_llm_params,\n",
@ -255,7 +256,7 @@
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
"print(\n",
" f\"Build context ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
" f\"Build context LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")\n",
"# inspect number of LLM calls and tokens in map-reduce\n",
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
@ -266,14 +267,14 @@
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
")\n",
"print(\n",
" f\"Map-reduce ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
" f\"Map-reduce LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -287,7 +288,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example of indexing from an existing in-memory dataframe\n",
"\n",
"Newer versions of GraphRAG let you submit a dataframe directly instead of running through the input processing step. This notebook demonstrates with regular or update runs.\n",
"\n",
"If performing an update, the assumption is that your dataframe contains only the new documents to add to the index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from pprint import pprint\n",
"\n",
"import pandas as pd\n",
"\n",
"import graphrag.api as api\n",
"from graphrag.config.load_config import load_config\n",
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PROJECT_DIRECTORY = \"<your project directory>\"\n",
"UPDATE = False\n",
"FILENAME = \"new_documents.parquet\" if UPDATE else \"<original_documents>.parquet\"\n",
"inputs = pd.read_parquet(f\"{PROJECT_DIRECTORY}/input/{FILENAME}\")\n",
"# Only the bare minimum for input. These are the same fields that would be present after the load_input_documents workflow\n",
"inputs = inputs.loc[:, [\"id\", \"title\", \"text\", \"creation_date\"]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate a `GraphRagConfig` object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Indexing API\n",
"\n",
"*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build an index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index_result: list[PipelineRunResult] = await api.build_index(\n",
" config=graphrag_config, input_documents=inputs, is_update_run=UPDATE\n",
")\n",
"\n",
"# index_result is a list of workflows that make up the indexing pipeline that was run\n",
"for workflow_result in index_result:\n",
" status = f\"error\\n{workflow_result.errors}\" if workflow_result.errors else \"success\"\n",
" print(f\"Workflow Name: {workflow_result.workflow}\\tStatus: {status}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query an index\n",
"\n",
"To query an index, several index files must first be read into memory and passed to the query API. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"entities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/entities.parquet\")\n",
"communities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/communities.parquet\")\n",
"community_reports = pd.read_parquet(\n",
" f\"{PROJECT_DIRECTORY}/output/community_reports.parquet\"\n",
")\n",
"\n",
"response, context = await api.global_search(\n",
" config=graphrag_config,\n",
" entities=entities,\n",
" communities=communities,\n",
" community_reports=community_reports,\n",
" community_level=2,\n",
" dynamic_community_selection=False,\n",
" response_type=\"Multiple Paragraphs\",\n",
" query=\"What are the top five themes of the dataset?\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pprint(context) # noqa: T203"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -19,8 +19,8 @@
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_covariates,\n",
@ -102,7 +102,9 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" collection_name=\"default-entity-description\",\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-entity-description\"\n",
" )\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
@ -195,37 +197,38 @@
"from graphrag.config.enums import ModelType\n",
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
"from graphrag.language_model.manager import ModelManager\n",
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
"\n",
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
"\n",
"chat_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIChat,\n",
" model=llm_model,\n",
" type=ModelType.Chat,\n",
" model_provider=\"openai\",\n",
" model=\"gpt-4.1\",\n",
" max_retries=20,\n",
")\n",
"chat_model = ModelManager().get_or_create_chat_model(\n",
" name=\"local_search\",\n",
" model_type=ModelType.OpenAIChat,\n",
" model_type=ModelType.Chat,\n",
" config=chat_config,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
"\n",
"embedding_config = LanguageModelConfig(\n",
" api_key=api_key,\n",
" type=ModelType.OpenAIEmbedding,\n",
" model=embedding_model,\n",
" type=ModelType.Embedding,\n",
" model_provider=\"openai\",\n",
" model=\"text-embedding-3-small\",\n",
" max_retries=20,\n",
")\n",
"\n",
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
" name=\"local_search_embedding\",\n",
" model_type=ModelType.OpenAIEmbedding,\n",
" model_type=ModelType.Embedding,\n",
" config=embedding_config,\n",
")"
")\n",
"\n",
"tokenizer = get_tokenizer(chat_config)"
]
},
{
@ -251,7 +254,7 @@
" entity_text_embeddings=description_embedding_store,\n",
" embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE\n",
" text_embedder=text_embedder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
")"
]
},
@ -314,7 +317,7 @@
"search_engine = LocalSearch(\n",
" model=chat_model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" model_params=model_params,\n",
" context_builder_params=local_context_params,\n",
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
@ -426,7 +429,7 @@
"question_generator = LocalQuestionGen(\n",
" model=chat_model,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" tokenizer=tokenizer,\n",
" model_params=model_params,\n",
" context_builder_params=local_context_params,\n",
")"
@ -451,7 +454,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -465,7 +468,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.10"
}
},
"nbformat": 4,

View File

@ -1,5 +1,7 @@
# Getting Started
⚠️ GraphRAG can consume a lot of LLM resources! We strongly recommend starting with the tutorial dataset here until you understand how the system works, and consider experimenting with fast/inexpensive models first before committing to a big indexing job.
## Requirements
[Python 3.10-3.12](https://www.python.org/downloads/)
@ -24,25 +26,25 @@ pip install graphrag
We need to set up a data project and some initial configuration. First let's get a sample dataset ready:
```sh
mkdir -p ./ragtest/input
mkdir -p ./christmas/input
```
Get a copy of A Christmas Carol by Charles Dickens from a trusted source:
```sh
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./ragtest/input/book.txt
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./christmas/input/book.txt
```
## Set Up Your Workspace Variables
To initialize your workspace, first run the `graphrag init` command.
Since we have already configured a directory named `./ragtest` in the previous step, run the following command:
Since we have already configured a directory named `./christmas` in the previous step, run the following command:
```sh
graphrag init --root ./ragtest
graphrag init --root ./christmas
```
This will create two files: `.env` and `settings.yaml` in the `./ragtest` directory.
This will create two files: `.env` and `settings.yaml` in the `./christmas` directory.
- `.env` contains the environment variables required to run the GraphRAG pipeline. If you inspect the file, you'll see a single environment variable defined,
`GRAPHRAG_API_KEY=<API_KEY>`. Replace `<API_KEY>` with your own OpenAI or Azure API key.
@ -65,11 +67,10 @@ deployment_name: <azure_model_deployment_name>
```
#### Using Managed Auth on Azure
To use managed auth, add an additional value to your model config and comment out or remove the api_key line:
To use managed auth, edit the auth_type in your model config and *remove* the api_key line:
```yaml
auth_type: azure_managed_identity # Default auth_type is is api_key
# api_key: ${GRAPHRAG_API_KEY}
```
You will also need to login with [az login](https://learn.microsoft.com/en-us/cli/azure/authenticate-azure-cli) and select the subscription with your endpoint.
@ -79,13 +80,13 @@ You will also need to login with [az login](https://learn.microsoft.com/en-us/cl
Finally we'll run the pipeline!
```sh
graphrag index --root ./ragtest
graphrag index --root ./christmas
```
![pipeline executing from the CLI](img/pipeline-running.png)
This process will take some time to run. This depends on the size of your input data, what model you're using, and the text chunk size being used (these can be configured in your `settings.yaml` file).
Once the pipeline is complete, you should see a new folder called `./ragtest/output` with a series of parquet files.
Once the pipeline is complete, you should see a new folder called `./christmas/output` with a series of parquet files.
# Using the Query Engine
@ -95,7 +96,7 @@ Here is an example using Global search to ask a high-level question:
```sh
graphrag query \
--root ./ragtest \
--root ./christmas \
--method global \
--query "What are the top themes in this story?"
```
@ -104,7 +105,7 @@ Here is an example using Local search to ask a more specific question about a pa
```sh
graphrag query \
--root ./ragtest \
--root ./christmas \
--method local \
--query "Who is Scrooge and what are his main relationships?"
```
@ -116,4 +117,4 @@ Please refer to [Query Engine](query/overview.md) docs for detailed information
- For more details about configuring GraphRAG, see the [configuration documentation](config/overview.md).
- To learn more about Initialization, refer to the [Initialization documentation](config/init.md).
- For more details about using the CLI, refer to the [CLI documentation](cli.md).
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.

View File

@ -47,6 +47,7 @@ At query time, these structures are used to provide materials for the LLM contex
- [_Global Search_](query/global_search.md) for reasoning about holistic questions about the corpus by leveraging the community summaries.
- [_Local Search_](query/local_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts.
- [_DRIFT Search_](query/drift_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts, but with the added context of community information.
- _Basic Search_ for those times when your query is best answered by baseline RAG (standard top _k_ vector search).
### Prompt Tuning

View File

@ -32,3 +32,20 @@ The GraphRAG library was designed with LLM interactions in mind, and a common se
Because of these potential error cases, we've added a cache layer around LLM interactions.
When completion requests are made using the same input set (prompt and tuning parameters), we return a cached result if one exists.
This allows our indexer to be more resilient to network issues, to act idempotently, and to provide a more efficient end-user experience.
### Providers & Factories
Several subsystems within GraphRAG use a factory pattern to register and retrieve provider implementations. This allows deep customization to support models, storage, and so on that you may use but isn't built directly into GraphRAG.
The following subsystems use a factory pattern that allows you to register your own implementations:
- [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support
- [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide
- [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage
- [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in
- [vector store](https://github.com/microsoft/graphrag/blob/main/graphrag/vector_stores/factory.py) - implement your own vector store other than the built-in lancedb, Azure AI Search, and CosmosDB ones built in
- [pipeline + workflows](https://github.com/microsoft/graphrag/blob/main/graphrag/index/workflows/factory.py) - implement your own workflow steps with a custom `run_workflow` function, or register an entire pipeline (list of named workflows)
The links for each of these subsystems point to the source code of the factory, which includes registration of the default built-in implementations. In addition, we have a detailed discussion of [language models](../config/models.md), which includes and example of a custom provider, and a [sample notebook](../examples_notebooks/custom_vector_store.ipynb) that demonstrates a custom vector store.
All of these factories allow you to register an impl using any string name you would like, even overriding built-in ones directly.

View File

@ -16,6 +16,10 @@ All input formats are loaded within GraphRAG and passed to the indexing pipeline
Also see the [outputs](outputs.md) documentation for the final documents table schema saved to parquet after pipeline completion.
## Bring-your-own DataFrame
As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. __You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same.
## Formats
We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries.

View File

@ -58,13 +58,13 @@ graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selec
```bash
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --domain "environmental news" \
--selection-method random --limit 10 --language English --max-tokens 2048 --chunk-size 256 --min-examples-required 3 \
--no-entity-types --output /path/to/output
--no-discover-entity-types --output /path/to/output
```
or, with minimal configuration (suggested):
```bash
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-entity-types
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-discover-entity-types
```
## Document Selection Methods
@ -79,15 +79,7 @@ After that, it uses one of the following selection methods to pick a sample to w
## Modify Env Vars
After running auto tuning, you should modify the following environment variables (or config variables) to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
- `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` = "prompts/entity_extraction.txt"
- `GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE` = "prompts/community_report.txt"
- `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` = "prompts/summarize_descriptions.txt"
or in your yaml config file:
After running auto tuning, you should modify the following config variables to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
```yaml
entity_extraction:

View File

@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
- `llm`: OpenAI model object to be used for response generation
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
- `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
- `token_encoder`: token encoder for tracking the budget for the algorithm.
- `tokenizer`: token encoder for tracking the budget for the algorithm.
- `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
## How to Use

View File

@ -1,6 +1,7 @@
# API Notebooks
- [API Overview Notebook](../../examples_notebooks/api_overview.ipynb)
- [Bring-Your-Own Vector Store](../../examples_notebooks/custom_vector_store.ipynb)
# Query Engine Notebooks

View File

@ -2,10 +2,3 @@
# Licensed under the MIT License
"""The GraphRAG package."""
import logging
from graphrag.logger.standard_logging import init_console_logger
logger = logging.getLogger(__name__)
init_console_logger()

View File

@ -11,6 +11,8 @@ Backwards compatibility is not guaranteed at this time.
import logging
from typing import Any
import pandas as pd
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import IndexingMethod
@ -18,7 +20,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.run_pipeline import run_pipeline
from graphrag.index.run.utils import create_callback_chain
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.typing.workflow import WorkflowFunction
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.standard_logging import init_loggers
@ -32,6 +33,8 @@ async def build_index(
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
additional_context: dict[str, Any] | None = None,
verbose: bool = False,
input_documents: pd.DataFrame | None = None,
) -> list[PipelineRunResult]:
"""Run the pipeline with the given configuration.
@ -47,13 +50,15 @@ async def build_index(
A list of callbacks to register.
additional_context : dict[str, Any] | None default=None
Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
input_documents : pd.DataFrame | None default=None.
Override document loading and parsing and supply your own dataframe of documents to index.
Returns
-------
list[PipelineRunResult]
The list of pipeline run results
"""
init_loggers(config=config)
init_loggers(config=config, verbose=verbose)
# Create callbacks for pipeline lifecycle events if provided
workflow_callbacks = (
@ -78,6 +83,7 @@ async def build_index(
callbacks=workflow_callbacks,
is_update_run=is_update_run,
additional_context=additional_context,
input_documents=input_documents,
):
outputs.append(output)
if output.errors and len(output.errors) > 0:
@ -90,11 +96,6 @@ async def build_index(
return outputs
def register_workflow_function(name: str, workflow: WorkflowFunction):
"""Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
PipelineFactory.register(name, workflow)
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
m = method.value if isinstance(method, IndexingMethod) else method
return f"{m}-update" if is_update_run else m

View File

@ -47,6 +47,7 @@ from graphrag.prompt_tune.generator.language import detect_language
from graphrag.prompt_tune.generator.persona import generate_persona
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
from graphrag.prompt_tune.types import DocSelectionType
from graphrag.tokenizer.get_tokenizer import get_tokenizer
logger = logging.getLogger(__name__)
@ -67,6 +68,7 @@ async def generate_indexing_prompts(
min_examples_required: PositiveInt = 2,
n_subset_max: PositiveInt = 300,
k: PositiveInt = 15,
verbose: bool = False,
) -> tuple[str, str, str]:
"""Generate indexing prompts.
@ -89,7 +91,7 @@ async def generate_indexing_prompts(
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
init_loggers(config=config)
init_loggers(config=config, verbose=verbose, filename="prompt-tuning.log")
# Retrieve documents
logger.info("Chunking documents...")
@ -165,7 +167,7 @@ async def generate_indexing_prompts(
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
encoding_model=extract_graph_llm_settings.encoding_model,
tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)

View File

@ -93,7 +93,7 @@ async def global_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -156,7 +156,7 @@ def global_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
communities_ = read_indexer_communities(communities, community_reports)
reports = read_indexer_reports(
@ -229,7 +229,11 @@ async def multi_index_global_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
@ -369,7 +373,7 @@ async def local_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -435,7 +439,7 @@ def local_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -508,8 +512,11 @@ async def multi_index_local_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
message = "Streaming not yet implemented for multi_index_local_search"
@ -730,7 +737,7 @@ async def drift_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -792,7 +799,7 @@ def drift_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -872,7 +879,11 @@ async def multi_index_drift_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:
@ -1065,7 +1076,7 @@ async def basic_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
callbacks = callbacks or []
full_response = ""
@ -1111,7 +1122,7 @@ def basic_search_streaming(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
vector_store_args = {}
for index, store in config.vector_store.items():
@ -1119,7 +1130,7 @@ def basic_search_streaming(
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.debug(msg)
description_embedding_store = get_embedding_store(
embedding_store = get_embedding_store(
config_args=vector_store_args,
embedding_name=text_unit_text_embedding,
)
@ -1130,7 +1141,7 @@ def basic_search_streaming(
search_engine = get_basic_search_engine(
config=config,
text_units=read_indexer_text_units(text_units),
text_unit_embeddings=description_embedding_store,
text_unit_embeddings=embedding_store,
system_prompt=prompt,
callbacks=callbacks,
)
@ -1164,7 +1175,11 @@ async def multi_index_basic_search(
-------
TODO: Document the search response type and format.
"""
init_loggers(config=config, verbose=verbose)
init_loggers(config=config, verbose=verbose, filename="query.log")
logger.warning(
"Multi-index search is deprecated and will be removed in GraphRAG v3."
)
# Streaming not supported yet
if streaming:

View File

@ -1,23 +1,24 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing create_cache method definition."""
"""Factory functions for creating a cache."""
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.config.enums import CacheType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
if TYPE_CHECKING:
from collections.abc import Callable
from graphrag.cache.pipeline_cache import PipelineCache
class CacheFactory:
@ -25,39 +26,90 @@ class CacheFactory:
Includes a method for users to register a custom cache implementation.
Configuration arguments are passed to each cache implementation as kwargs (where possible)
Configuration arguments are passed to each cache implementation as kwargs
for individual enforcement of required/optional arguments.
"""
cache_types: ClassVar[dict[str, type]] = {}
_registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
@classmethod
def register(cls, cache_type: str, cache: type):
"""Register a custom cache implementation."""
cls.cache_types[cache_type] = cache
def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
"""Register a custom cache implementation.
Args:
cache_type: The type identifier for the cache.
creator: A class or callable that creates an instance of PipelineCache.
"""
cls._registry[cache_type] = creator
@classmethod
def create_cache(
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
) -> PipelineCache:
"""Create or get a cache from the provided type."""
if not cache_type:
return NoopPipelineCache()
match cache_type:
case CacheType.none:
return NoopPipelineCache()
case CacheType.memory:
return InMemoryCache()
case CacheType.file:
return JsonPipelineCache(
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
)
case CacheType.blob:
return JsonPipelineCache(create_blob_storage(**kwargs))
case CacheType.cosmosdb:
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
case _:
if cache_type in cls.cache_types:
return cls.cache_types[cache_type](**kwargs)
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
"""Create a cache object from the provided type.
Args:
cache_type: The type of cache to create.
root_dir: The root directory for file-based caches.
kwargs: Additional keyword arguments for the cache constructor.
Returns
-------
A PipelineCache instance.
Raises
------
ValueError: If the cache type is not registered.
"""
if cache_type not in cls._registry:
msg = f"Unknown cache type: {cache_type}"
raise ValueError(msg)
return cls._registry[cache_type](**kwargs)
@classmethod
def get_cache_types(cls) -> list[str]:
"""Get the registered cache implementations."""
return list(cls._registry.keys())
@classmethod
def is_supported_type(cls, cache_type: str) -> bool:
"""Check if the given cache type is supported."""
return cache_type in cls._registry
# --- register built-in cache implementations ---
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
"""Create a file-based cache implementation."""
# Create storage with base_dir in kwargs since FilePipelineStorage expects it there
storage_kwargs = {"base_dir": root_dir, **kwargs}
storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
return JsonPipelineCache(storage)
def create_blob_cache(**kwargs) -> PipelineCache:
"""Create a blob storage-based cache implementation."""
storage = BlobPipelineStorage(**kwargs)
return JsonPipelineCache(storage)
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
"""Create a CosmosDB-based cache implementation."""
storage = CosmosDBPipelineStorage(**kwargs)
return JsonPipelineCache(storage)
def create_noop_cache(**_kwargs) -> PipelineCache:
"""Create a no-op cache implementation."""
return NoopPipelineCache()
def create_memory_cache(**kwargs) -> PipelineCache:
"""Create a memory cache implementation."""
return InMemoryCache(**kwargs)
# --- register built-in cache implementations ---
CacheFactory.register(CacheType.none.value, create_noop_cache)
CacheFactory.register(CacheType.memory.value, create_memory_cache)
CacheFactory.register(CacheType.file.value, create_file_cache)
CacheFactory.register(CacheType.blob.value, create_blob_cache)
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)

View File

@ -0,0 +1,46 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A logger that emits updates from the indexing engine to the console."""
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress
# ruff: noqa: T201
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a console."""
_verbose = False
def __init__(self, verbose=False):
self._verbose = verbose
def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback to signal when the entire pipeline starts."""
print("Starting pipeline with workflows:", ", ".join(names))
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback to signal when the entire pipeline ends."""
print("Pipeline complete")
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
print(f"Starting workflow: {name}")
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
print("") # account for potential return on prior progress
print(f"Workflow complete: {name}")
if self._verbose:
print(instance)
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
complete = progress.completed_items or 0
total = progress.total_items or 1
percent = round((complete / total) * 100)
start = f" {complete} / {total} "
print(f"{start:{'.'}<{percent}}", flush=True, end="\r")

View File

@ -10,7 +10,8 @@ import warnings
from pathlib import Path
import graphrag.api as api
from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.load_config import load_config
from graphrag.index.validate_config import validate_config_names
from graphrag.utils.cli import redact
@ -115,24 +116,12 @@ def _run_index(
# Initialize loggers and reporting config
init_loggers(
config=config,
root_dir=str(config.root_dir) if config.root_dir else None,
verbose=verbose,
)
if not cache:
config.cache.type = CacheType.none
# Log the configuration details
if config.reporting.type == ReportingType.file:
log_dir = Path(config.root_dir or "") / (config.reporting.base_dir or "")
log_path = log_dir / "logs.txt"
logger.info("Logging enabled at %s", log_path)
else:
logger.info(
"Logging not enabled for config %s",
redact(config.model_dump()),
)
if not skip_validation:
validate_config_names(config)
@ -154,6 +143,8 @@ def _run_index(
method=method,
is_update_run=is_update_run,
memory_profile=memprofile,
callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
verbose=verbose,
)
)
encountered_errors = any(

View File

@ -7,7 +7,6 @@ import logging
from pathlib import Path
import graphrag.api as api
from graphrag.config.enums import ReportingType
from graphrag.config.load_config import load_config
from graphrag.prompt_tune.generator.community_report_summarization import (
COMMUNITY_SUMMARIZATION_FILENAME,
@ -73,22 +72,13 @@ async def prompt_tune(
from graphrag.logger.standard_logging import init_loggers
# initialize loggers with config
init_loggers(
config=graph_config,
root_dir=str(root_path),
verbose=verbose,
)
init_loggers(config=graph_config, verbose=verbose, filename="prompt-tuning.log")
# log the configuration details
if graph_config.reporting.type == ReportingType.file:
log_dir = Path(root_path) / (graph_config.reporting.base_dir or "")
log_path = log_dir / "logs.txt"
logger.info("Logging enabled at %s", log_path)
else:
logger.info(
"Logging not enabled for config %s",
redact(graph_config.model_dump()),
)
logger.info("Starting prompt tune.")
logger.info(
"Using default configuration: %s",
redact(graph_config.model_dump()),
)
prompts = await api.generate_indexing_prompts(
config=graph_config,
@ -103,6 +93,7 @@ async def prompt_tune(
min_examples_required=min_examples_required,
n_subset_max=n_subset_max,
k=k,
verbose=verbose,
)
output_path = output.resolve()

View File

@ -4,7 +4,6 @@
"""CLI implementation of the query subcommand."""
import asyncio
import logging
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -19,8 +18,7 @@ from graphrag.utils.storage import load_table_from_storage, storage_has_table
if TYPE_CHECKING:
import pandas as pd
# Initialize standard logger
logger = logging.getLogger(__name__)
# ruff: noqa: T201
def run_global_search(
@ -61,10 +59,6 @@ def run_global_search(
final_community_reports_list = dataframe_dict["community_reports"]
index_names = dataframe_dict["index_names"]
logger.info(
"Running multi-index global search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
api.multi_index_global_search(
config=config,
@ -80,11 +74,7 @@ def run_global_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Query Response:\n%s", response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
print(response)
return response, context_data
# Otherwise, call the Single-Index Global Search API
@ -118,9 +108,9 @@ def run_global_search(
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
print(stream_chunk, end="")
sys.stdout.flush()
print()
return full_response, context_data
return asyncio.run(run_streaming_search())
@ -138,11 +128,8 @@ def run_global_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Global Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -188,11 +175,6 @@ def run_local_search(
final_relationships_list = dataframe_dict["relationships"]
index_names = dataframe_dict["index_names"]
logger.info(
"Running multi-index local search on indexes: %s",
dataframe_dict["index_names"],
)
# If any covariates tables are missing from any index, set the covariates list to None
if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]:
final_covariates_list = None
@ -216,11 +198,8 @@ def run_local_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Local Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
# Otherwise, call the Single-Index Local Search API
@ -259,9 +238,9 @@ def run_local_search(
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
print(stream_chunk, end="")
sys.stdout.flush()
print()
return full_response, context_data
return asyncio.run(run_streaming_search())
@ -281,11 +260,8 @@ def run_local_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Local Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -329,11 +305,6 @@ def run_drift_search(
final_relationships_list = dataframe_dict["relationships"]
index_names = dataframe_dict["index_names"]
logger.info(
"Running multi-index drift search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
api.multi_index_drift_search(
config=config,
@ -350,11 +321,8 @@ def run_drift_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("DRIFT Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
# Otherwise, call the Single-Index Drift Search API
@ -391,9 +359,9 @@ def run_drift_search(
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
print(stream_chunk, end="")
sys.stdout.flush()
print()
return full_response, context_data
return asyncio.run(run_streaming_search())
@ -413,11 +381,8 @@ def run_drift_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("DRIFT Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
@ -451,11 +416,6 @@ def run_basic_search(
final_text_units_list = dataframe_dict["text_units"]
index_names = dataframe_dict["index_names"]
logger.info(
"Running multi-index basic search on indexes: %s",
dataframe_dict["index_names"],
)
response, context_data = asyncio.run(
api.multi_index_basic_search(
config=config,
@ -466,11 +426,8 @@ def run_basic_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Basic Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
# Otherwise, call the Single-Index Basic Search API
@ -497,9 +454,9 @@ def run_basic_search(
verbose=verbose,
):
full_response += stream_chunk
print(stream_chunk, end="") # noqa: T201
sys.stdout.flush() # flush output buffer to display text immediately
print() # noqa: T201
print(stream_chunk, end="")
sys.stdout.flush()
print()
return full_response, context_data
return asyncio.run(run_streaming_search())
@ -512,11 +469,8 @@ def run_basic_search(
verbose=verbose,
)
)
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
logger.info("Basic Search Response:\n%s", response)
print(response)
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data

View File

@ -3,9 +3,10 @@
"""Common default configuration values."""
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar, Literal
from typing import ClassVar
from graphrag.config.embeddings import default_embeddings
from graphrag.config.enums import (
@ -18,27 +19,59 @@ from graphrag.config.enums import (
NounPhraseExtractorType,
ReportingType,
StorageType,
VectorStoreType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
)
from graphrag.vector_stores.factory import VectorStoreType
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
StaticRateLimiter,
)
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
ExponentialRetry,
)
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
IncrementalWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
NativeRetry,
)
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
RandomWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
DEFAULT_MODEL_PROVIDER = "openai"
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
ENCODING_MODEL = "cl100k_base"
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
"native": NativeRetry,
"exponential_backoff": ExponentialRetry,
"random_wait": RandomWaitRetry,
"incremental_wait": IncrementalWaitRetry,
}
DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
"static": StaticRateLimiter,
}
@dataclass
class BasicSearchDefaults:
"""Default values for basic search."""
@ -214,6 +247,7 @@ class ExtractGraphNLPDefaults:
normalize_edge_weights: bool = True
text_analyzer: TextAnalyzerDefaults = field(default_factory=TextAnalyzerDefaults)
concurrent_requests: int = 25
async_mode: AsyncType = AsyncType.Threaded
@dataclass
@ -274,6 +308,7 @@ class LanguageModelDefaults:
api_key: None = None
auth_type: ClassVar[AuthType] = AuthType.APIKey
model_provider: str | None = None
encoding_model: str = ""
max_tokens: int | None = None
temperature: float = 0
@ -291,9 +326,10 @@ class LanguageModelDefaults:
proxy: None = None
audience: None = None
model_supports_json: None = None
tokens_per_minute: Literal["auto"] = "auto"
requests_per_minute: Literal["auto"] = "auto"
retry_strategy: str = "native"
tokens_per_minute: None = None
requests_per_minute: None = None
rate_limit_strategy: str | None = "static"
retry_strategy: str = "exponential_backoff"
max_retries: int = 10
max_retry_wait: float = 10.0
concurrent_requests: int = 25
@ -393,6 +429,7 @@ class VectorStoreDefaults:
api_key: None = None
audience: None = None
database_name: None = None
schema: None = None
@dataclass

View File

@ -29,14 +29,14 @@ default_embeddings: list[str] = [
]
def create_collection_name(
def create_index_name(
container_name: str, embedding_name: str, validate: bool = True
) -> str:
"""
Create a collection name for the embedding store.
Create a index name for the embedding store.
Within any given vector store, we can have multiple sets of embeddings organized into projects.
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings

View File

@ -59,6 +59,14 @@ class StorageType(str, Enum):
return f'"{self.value}"'
class VectorStoreType(str, Enum):
"""The supported vector store types."""
LanceDB = "lancedb"
AzureAISearch = "azure_ai_search"
CosmosDB = "cosmosdb"
class ReportingType(str, Enum):
"""The reporting configuration type for the pipeline."""
@ -78,10 +86,12 @@ class ModelType(str, Enum):
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
Embedding = "embedding"
# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
Chat = "chat"
# Debug
MockChat = "mock_chat"

View File

@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
super().__init__(msg)
class AzureDeploymentNameMissingError(ValueError):
"""Azure Deployment Name missing error."""
def __init__(self, llm_type: str) -> None:
"""Init method definition."""
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
super().__init__(msg)
class LanguageModelConfigMissingError(ValueError):
"""Missing model configuration error."""

View File

@ -11,7 +11,6 @@ def get_embedding_settings(
vector_store_params: dict | None = None,
) -> dict:
"""Transform GraphRAG config into settings for workflows."""
# TEMP
embeddings_llm_settings = settings.get_language_model_config(
settings.embed_text.model_id
)

View File

@ -19,41 +19,34 @@ INIT_YAML = f"""\
models:
{defs.DEFAULT_CHAT_MODEL_ID}:
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
model: {defs.DEFAULT_CHAT_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}}
# audience: "https://cognitiveservices.azure.com/.default"
# organization: <organization_id>
model: {defs.DEFAULT_EMBEDDING_MODEL}
# deployment_name: <azure_model_deployment_name>
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
model_supports_json: true # recommended if this is available for your model.
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: native
retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
tokens_per_minute: null
requests_per_minute: null
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
api_key: ${{GRAPHRAG_API_KEY}}
model: {defs.DEFAULT_EMBEDDING_MODEL}
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
concurrent_requests: {language_model_defaults.concurrent_requests}
async_mode: {language_model_defaults.async_mode.value} # or asyncio
retry_strategy: {language_model_defaults.retry_strategy}
max_retries: {language_model_defaults.max_retries}
tokens_per_minute: null
requests_per_minute: null
### Input settings ###
@ -62,7 +55,6 @@ input:
type: {graphrag_config_defaults.input.storage.type.value} # or blob
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
chunks:
size: {graphrag_config_defaults.chunks.size}
@ -82,7 +74,7 @@ cache:
base_dir: "{graphrag_config_defaults.cache.base_dir}"
reporting:
type: {graphrag_config_defaults.reporting.type.value} # [file, blob, cosmosdb]
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
base_dir: "{graphrag_config_defaults.reporting.base_dir}"
vector_store:
@ -90,7 +82,6 @@ vector_store:
type: {vector_store_defaults.type}
db_uri: {vector_store_defaults.db_uri}
container_name: {vector_store_defaults.container_name}
overwrite: {vector_store_defaults.overwrite}
### Workflow settings ###
@ -112,6 +103,7 @@ summarize_descriptions:
extract_graph_nlp:
text_analyzer:
extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio
cluster_graph:
max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}

View File

@ -12,7 +12,7 @@ from graphrag.config.enums import CacheType
class CacheConfig(BaseModel):
"""The default configuration section for Cache."""
type: CacheType = Field(
type: CacheType | str = Field(
description="The cache type to use.",
default=graphrag_config_defaults.cache.type,
)

View File

@ -6,7 +6,7 @@
from pydantic import BaseModel, Field
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import NounPhraseExtractorType
from graphrag.config.enums import AsyncType, NounPhraseExtractorType
class TextAnalyzerConfig(BaseModel):
@ -68,3 +68,7 @@ class ExtractGraphNLPConfig(BaseModel):
description="The number of threads to use for the extraction process.",
default=graphrag_config_defaults.extract_graph_nlp.concurrent_requests,
)
async_mode: AsyncType = Field(
description="The async mode to use.",
default=graphrag_config_defaults.extract_graph_nlp.async_mode,
)

View File

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.errors import LanguageModelConfigMissingError
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
@ -36,7 +37,12 @@ from graphrag.config.models.summarize_descriptions_config import (
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
from graphrag.config.models.umap_config import UmapConfig
from graphrag.config.models.vector_store_config import VectorStoreConfig
from graphrag.vector_stores.factory import VectorStoreType
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
class GraphRagConfig(BaseModel):
@ -89,6 +95,47 @@ class GraphRagConfig(BaseModel):
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
def _validate_retry_services(self) -> None:
"""Validate the retry services configuration."""
retry_factory = RetryFactory()
for model_id, model in self.models.items():
if model.retry_strategy != "none":
if model.retry_strategy not in retry_factory:
msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
raise ValueError(msg)
_ = retry_factory.create(
strategy=model.retry_strategy,
max_retries=model.max_retries,
max_retry_wait=model.max_retry_wait,
)
def _validate_rate_limiter_services(self) -> None:
"""Validate the rate limiter services configuration."""
rate_limiter_factory = RateLimiterFactory()
for model_id, model in self.models.items():
if model.rate_limit_strategy is not None:
if model.rate_limit_strategy not in rate_limiter_factory:
msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
raise ValueError(msg)
rpm = (
model.requests_per_minute
if type(model.requests_per_minute) is int
else None
)
tpm = (
model.tokens_per_minute
if type(model.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
_ = rate_limiter_factory.create(
strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
)
input: InputConfig = Field(
description="The input configuration.", default=InputConfig()
)
@ -300,6 +347,11 @@ class GraphRagConfig(BaseModel):
raise ValueError(msg)
store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
def _validate_factories(self) -> None:
"""Validate the factories used in the configuration."""
self._validate_retry_services()
self._validate_rate_limiter_services()
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
"""Get a model configuration by ID.
@ -360,4 +412,5 @@ class GraphRagConfig(BaseModel):
self._validate_multi_output_base_dirs()
self._validate_update_index_output_base_dir()
self._validate_vector_store_db_uri()
self._validate_factories()
return self

View File

@ -3,6 +3,7 @@
"""Language model configuration."""
import logging
from typing import Literal
import tiktoken
@ -14,11 +15,12 @@ from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureApiVersionMissingError,
AzureDeploymentNameMissingError,
ConflictingSettingsError,
)
from graphrag.language_model.factory import ModelFactory
logger = logging.getLogger(__name__)
class LanguageModelConfig(BaseModel):
"""Language model configuration."""
@ -73,8 +75,11 @@ class LanguageModelConfig(BaseModel):
ConflictingSettingsError
If the Azure authentication type conflicts with the model being used.
"""
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
if (
self.auth_type == AuthType.AzureManagedIdentity
and self.type != ModelType.AzureOpenAIChat
and self.type != ModelType.AzureOpenAIEmbedding
and self.model_provider != "azure" # indicates Litellm + AOI
):
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
raise ConflictingSettingsError(msg)
@ -93,6 +98,35 @@ class LanguageModelConfig(BaseModel):
if not ModelFactory.is_supported_model(self.type):
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
raise KeyError(msg)
if self.type in [
"openai_chat",
"openai_embedding",
"azure_openai_chat",
"azure_openai_embedding",
]:
msg = f"Model config based on fnllm is deprecated and will be removed in GraphRAG v3, please use {ModelType.Chat} or {ModelType.Embedding} instead to switch to LiteLLM config."
logger.warning(msg)
model_provider: str | None = Field(
description="The model provider to use.",
default=language_model_defaults.model_provider,
)
def _validate_model_provider(self) -> None:
"""Validate the model provider.
Required when using Litellm.
Raises
------
KeyError
If the model provider is not recognized.
"""
if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
self.model_provider is None or self.model_provider.strip() == ""
):
msg = f"Model provider must be specified when using type == {self.type}."
raise KeyError(msg)
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(
@ -103,12 +137,27 @@ class LanguageModelConfig(BaseModel):
def _validate_encoding_model(self) -> None:
"""Validate the encoding model.
The default behavior is to use an encoding model that matches the LLM model.
LiteLLM supports 100+ models and their tokenization. There is no need to
set the encoding model when using the new LiteLLM provider as was done with fnllm provider.
Users can still manually specify a tiktoken based encoding model to use even with the LiteLLM provider
in which case the specified encoding model will be used regardless of the LLM model being used, even if
it is not an openai based model.
If not using LiteLLM provider, set the encoding model based on the LLM model name.
This is for backward compatibility with existing fnllm provider until fnllm is removed.
Raises
------
KeyError
If the model name is not recognized.
"""
if self.encoding_model.strip() == "":
if (
self.type != ModelType.Chat
and self.type != ModelType.Embedding
and self.encoding_model.strip() == ""
):
self.encoding_model = tiktoken.encoding_name_for_model(self.model)
api_base: str | None = Field(
@ -129,6 +178,7 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.api_base is None or self.api_base.strip() == ""):
raise AzureApiBaseMissingError(self.type)
@ -150,6 +200,7 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.api_version is None or self.api_version.strip() == ""):
raise AzureApiVersionMissingError(self.type)
@ -171,8 +222,10 @@ class LanguageModelConfig(BaseModel):
if (
self.type == ModelType.AzureOpenAIChat
or self.type == ModelType.AzureOpenAIEmbedding
or self.model_provider == "azure" # indicates Litellm + AOI
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
raise AzureDeploymentNameMissingError(self.type)
msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
logger.debug(msg)
organization: str | None = Field(
description="The organization to use for the LLM service.",
@ -212,6 +265,14 @@ class LanguageModelConfig(BaseModel):
msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
raise ValueError(msg)
if (
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
and self.rate_limit_strategy is not None
and self.tokens_per_minute == "auto"
):
msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
raise ValueError(msg)
requests_per_minute: int | Literal["auto"] | None = Field(
description="The number of requests per minute to use for the LLM service.",
default=language_model_defaults.requests_per_minute,
@ -230,6 +291,19 @@ class LanguageModelConfig(BaseModel):
msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
raise ValueError(msg)
if (
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
and self.rate_limit_strategy is not None
and self.requests_per_minute == "auto"
):
msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
raise ValueError(msg)
rate_limit_strategy: str | None = Field(
description="The rate limit strategy to use for the LLM service.",
default=language_model_defaults.rate_limit_strategy,
)
retry_strategy: str = Field(
description="The retry strategy to use for the LLM service.",
default=language_model_defaults.retry_strategy,
@ -318,6 +392,7 @@ class LanguageModelConfig(BaseModel):
@model_validator(mode="after")
def _validate_model(self):
self._validate_type()
self._validate_model_provider()
self._validate_auth_type()
self._validate_api_key()
self._validate_tokens_per_minute()

View File

@ -12,7 +12,7 @@ from graphrag.config.enums import ReportingType
class ReportingConfig(BaseModel):
"""The default configuration section for Reporting."""
type: ReportingType = Field(
type: ReportingType | str = Field(
description="The reporting type to use.",
default=graphrag_config_defaults.reporting.type,
)

View File

@ -14,7 +14,7 @@ from graphrag.config.enums import StorageType
class StorageConfig(BaseModel):
"""The default configuration section for storage."""
type: StorageType = Field(
type: StorageType | str = Field(
description="The storage type to use.",
default=graphrag_config_defaults.storage.type,
)

View File

@ -39,7 +39,7 @@ class SummarizeDescriptionsConfig(BaseModel):
self, root_dir: str, model_config: LanguageModelConfig
) -> dict:
"""Get the resolved description summarization strategy."""
from graphrag.index.operations.summarize_descriptions import (
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
SummarizeStrategyType,
)

View File

@ -39,7 +39,7 @@ class TextEmbeddingConfig(BaseModel):
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
"""Get the resolved text embedding strategy."""
from graphrag.index.operations.embed_text import (
from graphrag.index.operations.embed_text.embed_text import (
TextEmbedStrategyType,
)

View File

@ -6,7 +6,9 @@
from pydantic import BaseModel, Field, model_validator
from graphrag.config.defaults import vector_store_defaults
from graphrag.vector_stores.factory import VectorStoreType
from graphrag.config.embeddings import all_embeddings
from graphrag.config.enums import VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
class VectorStoreConfig(BaseModel):
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
default=vector_store_defaults.overwrite,
)
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
def _validate_embeddings_schema(self) -> None:
"""Validate the embeddings schema."""
for name in self.embeddings_schema:
if name not in all_embeddings:
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
raise ValueError(msg)
if self.type == VectorStoreType.CosmosDB:
for id_field in self.embeddings_schema:
if id_field != "id":
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_db_uri()
self._validate_url()
self._validate_embeddings_schema()
return self

View File

@ -0,0 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
import re
from pydantic import BaseModel, Field, model_validator
DEFAULT_VECTOR_SIZE: int = 1536
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def is_valid_field_name(field: str) -> bool:
"""Check if a field name is valid for CosmosDB."""
return bool(VALID_IDENTIFIER_REGEX.match(field))
class VectorStoreSchemaConfig(BaseModel):
"""The default configuration section for Vector Store Schema."""
id_field: str = Field(
description="The ID field to use.",
default="id",
)
vector_field: str = Field(
description="The vector field to use.",
default="vector",
)
text_field: str = Field(
description="The text field to use.",
default="text",
)
attributes_field: str = Field(
description="The attributes field to use.",
default="attributes",
)
vector_size: int = Field(
description="The vector size to use.",
default=DEFAULT_VECTOR_SIZE,
)
index_name: str | None = Field(description="The index name to use.", default=None)
def _validate_schema(self) -> None:
"""Validate the schema."""
for field in [
self.id_field,
self.vector_field,
self.text_field,
self.attributes_field,
]:
if not is_valid_field_name(field):
msg = f"Unsafe or invalid field name: {field}"
raise ValueError(msg)
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model."""
self._validate_schema()
return self

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Factory module."""

View File

@ -0,0 +1,68 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Factory ABC."""
from abc import ABC
from collections.abc import Callable
from typing import Any, ClassVar, Generic, TypeVar
T = TypeVar("T", covariant=True)
class Factory(ABC, Generic[T]):
"""Abstract base class for factories."""
_instance: ClassVar["Factory | None"] = None
def __new__(cls, *args: Any, **kwargs: Any) -> "Factory":
"""Create a new instance of Factory if it does not exist."""
if cls._instance is None:
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, "_initialized"):
self._services: dict[str, Callable[..., T]] = {}
self._initialized = True
def __contains__(self, strategy: str) -> bool:
"""Check if a strategy is registered."""
return strategy in self._services
def keys(self) -> list[str]:
"""Get a list of registered strategy names."""
return list(self._services.keys())
def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None:
"""
Register a new service.
Args
----
strategy: The name of the strategy.
service_initializer: A callable that creates an instance of T.
"""
self._services[strategy] = service_initializer
def create(self, *, strategy: str, **kwargs: Any) -> T:
"""
Create a service instance based on the strategy.
Args
----
strategy: The name of the strategy.
**kwargs: Additional arguments to pass to the service initializer.
Returns
-------
An instance of T.
Raises
------
ValueError: If the strategy is not registered.
"""
if strategy not in self._services:
msg = f"Strategy '{strategy}' is not registered."
raise ValueError(msg)
return self._services[strategy](**kwargs)

View File

@ -24,12 +24,17 @@ async def build_noun_graph(
text_analyzer: BaseNounPhraseExtractor,
normalize_edge_weights: bool,
num_threads: int = 4,
async_mode: AsyncType = AsyncType.Threaded,
cache: PipelineCache | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""Build a noun graph from text units."""
text_units = text_unit_df.loc[:, ["id", "text"]]
nodes_df = await _extract_nodes(
text_units, text_analyzer, num_threads=num_threads, cache=cache
text_units,
text_analyzer,
num_threads=num_threads,
async_mode=async_mode,
cache=cache,
)
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
return (nodes_df, edges_df)
@ -39,6 +44,7 @@ async def _extract_nodes(
text_unit_df: pd.DataFrame,
text_analyzer: BaseNounPhraseExtractor,
num_threads: int = 4,
async_mode: AsyncType = AsyncType.Threaded,
cache: PipelineCache | None = None,
) -> pd.DataFrame:
"""
@ -64,7 +70,7 @@ async def _extract_nodes(
text_unit_df,
extract,
num_threads=num_threads,
async_type=AsyncType.Threaded,
async_type=async_mode,
progress_msg="extract noun phrases progress: ",
)

View File

@ -11,7 +11,7 @@ import tiktoken
from graphrag.config.models.chunking_config import ChunkingConfig
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.index.text_splitting.text_splitting import (
Tokenizer,
TokenChunkerOptions,
split_multiple_texts_on_tokens,
)
from graphrag.logger.progress import ProgressTicker
@ -45,7 +45,7 @@ def run_tokens(
encode, decode = get_encoding_fn(encoding_name)
return split_multiple_texts_on_tokens(
input,
Tokenizer(
TokenChunkerOptions(
chunk_overlap=chunk_overlap,
tokens_per_chunk=tokens_per_chunk,
encode=encode,

View File

@ -2,10 +2,3 @@
# Licensed under the MIT License
"""The Indexing Engine text embed package root."""
from graphrag.index.operations.embed_text.embed_text import (
TextEmbedStrategyType,
embed_text,
)
__all__ = ["TextEmbedStrategyType", "embed_text"]

View File

@ -12,7 +12,8 @@ import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.embeddings import create_collection_name
from graphrag.config.embeddings import create_index_name
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
from graphrag.vector_stores.factory import VectorStoreFactory
@ -49,9 +50,9 @@ async def embed_text(
vector_store_config = strategy.get("vector_store")
if vector_store_config:
collection_name = _get_collection_name(vector_store_config, embedding_name)
index_name = _get_index_name(vector_store_config, embedding_name)
vector_store: BaseVectorStore = _create_vector_store(
vector_store_config, collection_name
vector_store_config, index_name, embedding_name
)
vector_store_workflow_config = vector_store_config.get(
embedding_name, vector_store_config
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
def _create_vector_store(
vector_store_config: dict, collection_name: str
vector_store_config: dict, index_name: str, embedding_name: str | None = None
) -> BaseVectorStore:
vector_store_type: str = str(vector_store_config.get("type"))
if collection_name:
vector_store_config.update({"collection_name": collection_name})
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
"embeddings_schema", {}
)
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
if (
embeddings_schema is not None
and embedding_name is not None
and embedding_name in embeddings_schema
):
raw_config = embeddings_schema[embedding_name]
if isinstance(raw_config, dict):
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
else:
single_embedding_config = raw_config
if single_embedding_config.index_name is None:
single_embedding_config.index_name = index_name
vector_store = VectorStoreFactory().create_vector_store(
vector_store_type, kwargs=vector_store_config
vector_store_schema_config=single_embedding_config,
vector_store_type=vector_store_type,
**vector_store_config,
)
vector_store.connect(**vector_store_config)
return vector_store
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
container_name = vector_store_config.get("container_name", "default")
collection_name = create_collection_name(container_name, embedding_name)
index_name = create_index_name(container_name, embedding_name)
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
logger.info(msg)
return collection_name
return index_name
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:

View File

@ -18,6 +18,7 @@ from graphrag.index.utils.is_null import is_null
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.logger.progress import ProgressTicker, progress_ticker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
logger = logging.getLogger(__name__)
@ -79,7 +80,7 @@ def _get_splitter(
config: LanguageModelConfig, batch_max_tokens: int
) -> TokenTextSplitter:
return TokenTextSplitter(
encoding_name=config.encoding_model,
tokenizer=get_tokenizer(model_config=config),
chunk_size=batch_max_tokens,
)

View File

@ -8,10 +8,12 @@ import graphrag.data_model.schemas as schemas
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
def build_mixed_context(
context: list[dict], tokenizer: Tokenizer, max_context_tokens: int
) -> str:
"""
Build parent context by concatenating all sub-communities' contexts.
@ -45,9 +47,10 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
new_context_string = sort_context(
local_context=remaining_local_context + final_local_contexts,
tokenizer=tokenizer,
sub_community_reports=substitute_reports,
)
if num_tokens(new_context_string) <= max_context_tokens:
if tokenizer.num_tokens(new_context_string) <= max_context_tokens:
exceeded_limit = False
context_string = new_context_string
break
@ -63,7 +66,7 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
new_context_string = pd.DataFrame(substitute_reports).to_csv(
index=False, sep=","
)
if num_tokens(new_context_string) > max_context_tokens:
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
break
context_string = new_context_string

View File

@ -30,7 +30,7 @@ from graphrag.index.utils.dataframes import (
where_column_equals,
)
from graphrag.logger.progress import progress_iterable
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -39,6 +39,7 @@ def build_local_context(
nodes,
edges,
claims,
tokenizer: Tokenizer,
callbacks: WorkflowCallbacks,
max_context_tokens: int = 16_000,
):
@ -49,7 +50,7 @@ def build_local_context(
for level in progress_iterable(levels, callbacks.progress, len(levels)):
communities_at_level_df = _prepare_reports_at_level(
nodes, edges, claims, level, max_context_tokens
nodes, edges, claims, tokenizer, level, max_context_tokens
)
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
@ -63,6 +64,7 @@ def _prepare_reports_at_level(
node_df: pd.DataFrame,
edge_df: pd.DataFrame,
claim_df: pd.DataFrame | None,
tokenizer: Tokenizer,
level: int,
max_context_tokens: int = 16_000,
) -> pd.DataFrame:
@ -181,6 +183,7 @@ def _prepare_reports_at_level(
# Generate community-level context strings using vectorized batch processing
return parallel_sort_context_batch(
community_df,
tokenizer=tokenizer,
max_context_tokens=max_context_tokens,
)
@ -189,6 +192,7 @@ def build_level_context(
report_df: pd.DataFrame | None,
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
tokenizer: Tokenizer,
level: int,
max_context_tokens: int,
) -> pd.DataFrame:
@ -219,11 +223,11 @@ def build_level_context(
if report_df is None or report_df.empty:
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
invalid_context_df, max_context_tokens
invalid_context_df, tokenizer, max_context_tokens
)
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
:, schemas.CONTEXT_STRING
].map(num_tokens)
].map(tokenizer.num_tokens)
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = False
return union(valid_context_df, invalid_context_df)
@ -237,6 +241,7 @@ def build_level_context(
invalid_context_df,
sub_context_df,
community_hierarchy_df,
tokenizer,
max_context_tokens,
)
@ -244,11 +249,13 @@ def build_level_context(
# this should be rare, but if it happens, we will just trim the local context to fit the limit
remaining_df = _antijoin_reports(invalid_context_df, community_df)
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
remaining_df, max_context_tokens
remaining_df, tokenizer, max_context_tokens
)
result = union(valid_context_df, community_df, remaining_df)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(
tokenizer.num_tokens
)
result[schemas.CONTEXT_EXCEED_FLAG] = False
return result
@ -269,19 +276,29 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
return antijoin(df, reports, schemas.COMMUNITY_ID)
def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
def _sort_and_trim_context(
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(
series, lambda x: sort_context(x, max_context_tokens=max_context_tokens)
series,
lambda x: sort_context(
x, tokenizer=tokenizer, max_context_tokens=max_context_tokens
),
)
def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
def _build_mixed_context(
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
) -> pd.Series:
"""Sort and trim context to fit the limit."""
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
return transform_series(
series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens)
series,
lambda x: build_mixed_context(
x, tokenizer, max_context_tokens=max_context_tokens
),
)
@ -303,6 +320,7 @@ def _get_community_df(
invalid_context_df: pd.DataFrame,
sub_context_df: pd.DataFrame,
community_hierarchy_df: pd.DataFrame,
tokenizer: Tokenizer,
max_context_tokens: int,
) -> pd.DataFrame:
"""Get community context for each community."""
@ -338,7 +356,7 @@ def _get_community_df(
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
community_df, max_context_tokens
community_df, tokenizer, max_context_tokens
)
community_df[schemas.COMMUNITY_LEVEL] = level
return community_df

View File

@ -5,11 +5,12 @@
import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
def sort_context(
local_context: list[dict],
tokenizer: Tokenizer,
sub_community_reports: list[dict] | None = None,
max_context_tokens: int | None = None,
node_name_column: str = schemas.TITLE,
@ -112,7 +113,10 @@ def sort_context(
new_context_string = _get_context_string(
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
)
if max_context_tokens and num_tokens(new_context_string) > max_context_tokens:
if (
max_context_tokens
and tokenizer.num_tokens(new_context_string) > max_context_tokens
):
break
context_string = new_context_string
@ -122,7 +126,9 @@ def sort_context(
)
def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False):
def parallel_sort_context_batch(
community_df, tokenizer: Tokenizer, max_context_tokens, parallel=False
):
"""Calculate context using parallelization if enabled."""
if parallel:
# Use ThreadPoolExecutor for parallel execution
@ -131,7 +137,9 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
with ThreadPoolExecutor(max_workers=None) as executor:
context_strings = list(
executor.map(
lambda x: sort_context(x, max_context_tokens=max_context_tokens),
lambda x: sort_context(
x, tokenizer, max_context_tokens=max_context_tokens
),
community_df[schemas.ALL_CONTEXT],
)
)
@ -141,13 +149,13 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
# Assign context strings directly to the DataFrame
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda context_list: sort_context(
context_list, max_context_tokens=max_context_tokens
context_list, tokenizer, max_context_tokens=max_context_tokens
)
)
# Calculate other columns
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
num_tokens
tokenizer.num_tokens
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
community_df[schemas.CONTEXT_SIZE] > max_context_tokens

View File

@ -16,7 +16,6 @@ from graphrag.index.operations.summarize_communities.typing import (
Finding,
StrategyConfig,
)
from graphrag.index.utils.rate_limiter import RateLimiter
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import ChatModel
@ -51,8 +50,6 @@ async def _run_extractor(
level: int,
args: StrategyConfig,
) -> CommunityReport | None:
# RateLimiter
rate_limiter = RateLimiter(rate=1, per=60)
extractor = CommunityReportsExtractor(
model,
extraction_prompt=args.get("extraction_prompt", None),
@ -63,7 +60,6 @@ async def _run_extractor(
)
try:
await rate_limiter.acquire()
results = await extractor(input)
report = results.structured_output
if report is None:

View File

@ -23,6 +23,7 @@ from graphrag.index.operations.summarize_communities.utils import (
)
from graphrag.index.utils.derive_from_rows import derive_from_rows
from graphrag.logger.progress import progress_ticker
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -35,6 +36,7 @@ async def summarize_communities(
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict,
tokenizer: Tokenizer,
max_input_length: int,
async_mode: AsyncType = AsyncType.AsyncIO,
num_threads: int = 4,
@ -44,7 +46,6 @@ async def summarize_communities(
tick = progress_ticker(callbacks.progress, len(local_contexts))
strategy_exec = load_strategy(strategy["type"])
strategy_config = {**strategy}
community_hierarchy = (
communities.explode("children")
.rename({"children": "sub_community"}, axis=1)
@ -60,6 +61,7 @@ async def summarize_communities(
community_hierarchy_df=community_hierarchy,
local_context_df=local_contexts,
level=level,
tokenizer=tokenizer,
max_context_tokens=max_input_length,
)
level_contexts.append(level_context)

View File

@ -18,7 +18,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
sort_context,
)
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -27,6 +27,7 @@ def build_local_context(
community_membership_df: pd.DataFrame,
text_units_df: pd.DataFrame,
node_df: pd.DataFrame,
tokenizer: Tokenizer,
max_context_tokens: int = 16000,
) -> pd.DataFrame:
"""
@ -69,10 +70,10 @@ def build_local_context(
.reset_index()
)
context_df[schemas.CONTEXT_STRING] = context_df[schemas.ALL_CONTEXT].apply(
lambda x: sort_context(x)
lambda x: sort_context(x, tokenizer)
)
context_df[schemas.CONTEXT_SIZE] = context_df[schemas.CONTEXT_STRING].apply(
lambda x: num_tokens(x)
lambda x: tokenizer.num_tokens(x)
)
context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply(
lambda x: x > max_context_tokens
@ -86,6 +87,7 @@ def build_level_context(
community_hierarchy_df: pd.DataFrame,
local_context_df: pd.DataFrame,
level: int,
tokenizer: Tokenizer,
max_context_tokens: int = 16000,
) -> pd.DataFrame:
"""
@ -116,10 +118,12 @@ def build_level_context(
invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[
schemas.ALL_CONTEXT
].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
].apply(
lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens)
)
invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[
schemas.CONTEXT_STRING
].apply(lambda x: num_tokens(x))
].apply(lambda x: tokenizer.num_tokens(x))
invalid_context_df.loc[:, [schemas.CONTEXT_EXCEED_FLAG]] = False
return pd.concat([valid_context_df, invalid_context_df])
@ -199,10 +203,10 @@ def build_level_context(
.reset_index()
)
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
lambda x: build_mixed_context(x, max_context_tokens)
lambda x: build_mixed_context(x, tokenizer, max_context_tokens)
)
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
lambda x: num_tokens(x)
lambda x: tokenizer.num_tokens(x)
)
community_df[schemas.CONTEXT_EXCEED_FLAG] = False
community_df[schemas.COMMUNITY_LEVEL] = level
@ -220,10 +224,10 @@ def build_level_context(
)
remaining_df[schemas.CONTEXT_STRING] = cast(
"pd.DataFrame", remaining_df[schemas.ALL_CONTEXT]
).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
).apply(lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens))
remaining_df[schemas.CONTEXT_SIZE] = cast(
"pd.DataFrame", remaining_df[schemas.CONTEXT_STRING]
).apply(lambda x: num_tokens(x))
).apply(lambda x: tokenizer.num_tokens(x))
remaining_df[schemas.CONTEXT_EXCEED_FLAG] = False
return cast(

View File

@ -8,7 +8,7 @@ import logging
import pandas as pd
import graphrag.data_model.schemas as schemas
from graphrag.query.llm.text_utils import num_tokens
from graphrag.tokenizer.tokenizer import Tokenizer
logger = logging.getLogger(__name__)
@ -57,6 +57,7 @@ def get_context_string(
def sort_context(
local_context: list[dict],
tokenizer: Tokenizer,
sub_community_reports: list[dict] | None = None,
max_context_tokens: int | None = None,
) -> str:
@ -73,7 +74,7 @@ def sort_context(
new_context_string = get_context_string(
current_text_units, sub_community_reports
)
if num_tokens(new_context_string) > max_context_tokens:
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
break
context_string = new_context_string

View File

@ -2,17 +2,3 @@
# Licensed under the MIT License
"""Root package for description summarization."""
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.operations.summarize_descriptions.typing import (
SummarizationStrategy,
SummarizeStrategyType,
)
__all__ = [
"SummarizationStrategy",
"SummarizeStrategyType",
"summarize_descriptions",
]

View File

@ -7,9 +7,9 @@ import json
from dataclasses import dataclass
from graphrag.index.typing.error_handler import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.language_model.protocol.base import ChatModel
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.tokenizer.get_tokenizer import get_tokenizer
# these tokens are used in the prompt
ENTITY_NAME_KEY = "entity_name"
@ -45,7 +45,7 @@ class SummarizeExtractor:
"""Init method definition."""
# TODO: streamline construction
self._model = model_invoker
self._tokenizer = get_tokenizer(model_invoker.config)
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
self._on_error = on_error or (lambda _e, _s, _d: None)
self._max_summary_length = max_summary_length
@ -85,14 +85,14 @@ class SummarizeExtractor:
descriptions = sorted(descriptions)
# Iterate over descriptions, adding all until the max input tokens is reached
usable_tokens = self._max_input_tokens - num_tokens_from_string(
usable_tokens = self._max_input_tokens - self._tokenizer.num_tokens(
self._summarization_prompt
)
descriptions_collected = []
result = ""
for i, description in enumerate(descriptions):
usable_tokens -= num_tokens_from_string(description)
usable_tokens -= self._tokenizer.num_tokens(description)
descriptions_collected.append(description)
# If buffer is full, or all descriptions have been added, summarize
@ -109,8 +109,8 @@ class SummarizeExtractor:
descriptions_collected = [result]
usable_tokens = (
self._max_input_tokens
- num_tokens_from_string(self._summarization_prompt)
- num_tokens_from_string(result)
- self._tokenizer.num_tokens(self._summarization_prompt)
- self._tokenizer.num_tokens(result)
)
return result

View File

@ -11,6 +11,8 @@ from collections.abc import AsyncIterable
from dataclasses import asdict
from typing import Any
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import create_run_context
@ -30,6 +32,7 @@ async def run_pipeline(
callbacks: WorkflowCallbacks,
is_update_run: bool = False,
additional_context: dict[str, Any] | None = None,
input_documents: pd.DataFrame | None = None,
) -> AsyncIterable[PipelineRunResult]:
"""Run all workflows using a simplified pipeline."""
root_dir = config.root_dir
@ -60,6 +63,11 @@ async def run_pipeline(
state["update_timestamp"] = update_timestamp
# if the user passes in a df directly, write directly to storage so we can skip finding/parsing later
if input_documents is not None:
await write_table_to_storage(input_documents, "documents", delta_storage)
pipeline.remove("load_update_documents")
context = create_run_context(
input_storage=input_storage,
output_storage=delta_storage,
@ -72,6 +80,11 @@ async def run_pipeline(
else:
logger.info("Running standard indexing.")
# if the user passes in a df directly, write directly to storage so we can skip finding/parsing later
if input_documents is not None:
await write_table_to_storage(input_documents, "documents", output_storage)
pipeline.remove("load_input_documents")
context = create_run_context(
input_storage=input_storage,
output_storage=output_storage,

View File

@ -5,16 +5,16 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from typing import Any, Literal, cast
from typing import Any, cast
import pandas as pd
import tiktoken
import graphrag.config.defaults as defs
from graphrag.index.operations.chunk_text.typing import TextChunk
from graphrag.logger.progress import ProgressTicker
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class Tokenizer:
"""Tokenizer data class."""
class TokenChunkerOptions:
"""TokenChunkerOptions data class."""
chunk_overlap: int
"""Overlap in tokens between chunks"""
@ -83,44 +83,18 @@ class NoopTextSplitter(TextSplitter):
class TokenTextSplitter(TextSplitter):
"""Token text splitter class definition."""
_allowed_special: Literal["all"] | set[str]
_disallowed_special: Literal["all"] | Collection[str]
def __init__(
self,
encoding_name: str = defs.ENCODING_MODEL,
model_name: str | None = None,
allowed_special: Literal["all"] | set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = "all",
tokenizer: Tokenizer | None = None,
**kwargs: Any,
):
"""Init method definition."""
super().__init__(**kwargs)
if model_name is not None:
try:
enc = tiktoken.encoding_for_model(model_name)
except KeyError:
logger.exception(
"Model %s not found, using %s", model_name, encoding_name
)
enc = tiktoken.get_encoding(encoding_name)
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special = allowed_special or set()
self._disallowed_special = disallowed_special
def encode(self, text: str) -> list[int]:
"""Encode the given text into an int-vector."""
return self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
self._tokenizer = tokenizer or get_tokenizer()
def num_tokens(self, text: str) -> int:
"""Return the number of tokens in a string."""
return len(self.encode(text))
return self._tokenizer.num_tokens(text)
def split_text(self, text: str | list[str]) -> list[str]:
"""Split text method."""
@ -132,17 +106,17 @@ class TokenTextSplitter(TextSplitter):
msg = f"Attempting to split a non-string value, actual is {type(text)}"
raise TypeError(msg)
tokenizer = Tokenizer(
token_chunker_options = TokenChunkerOptions(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=lambda text: self.encode(text),
encode=self._tokenizer.encode,
)
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
"""Split a single text and return chunks using the tokenizer."""
result = []
input_ids = tokenizer.encode(text)
@ -166,7 +140,7 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
# So we could have better control over the chunking process
def split_multiple_texts_on_tokens(
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
) -> list[TextChunk]:
"""Split multiple texts and return chunks with metadata using the tokenizer."""
result = []

View File

@ -21,3 +21,7 @@ class Pipeline:
def names(self) -> list[str]:
"""Return the names of the workflows in the pipeline."""
return [name for name, _ in self.workflows]
def remove(self, name: str) -> None:
"""Remove a workflow from the pipeline by name."""
self.workflows = [w for w in self.workflows if w[0] != name]

View File

@ -1,40 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Rate limiter utility."""
import asyncio
import time
class RateLimiter:
"""
The original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled.
The RateLimiter was introduced to ensure that the CommunityReportsExtractor could be scheduled to adhere to rate configurations on a per-minute basis.
"""
# TODO: RateLimiter scheduled: using asyncio for async_mode
def __init__(self, rate: int, per: int):
self.rate = rate
self.per = per
self.allowance = rate
self.last_check = time.monotonic()
async def acquire(self):
"""Acquire a token from the rate limiter."""
current = time.monotonic()
elapsed = current - self.last_check
self.last_check = current
self.allowance += elapsed * (self.rate / self.per)
if self.allowance > self.rate:
self.allowance = self.rate
if self.allowance < 1.0:
sleep_time = (1.0 - self.allowance) * (self.per / self.rate)
await asyncio.sleep(sleep_time)
self.allowance = 0.0
else:
self.allowance -= 1.0

View File

@ -1,44 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Utilities for working with tokens."""
import logging
import tiktoken
import graphrag.config.defaults as defs
DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL
logger = logging.getLogger(__name__)
def num_tokens_from_string(
string: str, model: str | None = None, encoding_name: str | None = None
) -> int:
"""Return the number of tokens in a text string."""
if model is not None:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
logger.warning(msg)
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
else:
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
return len(encoding.encode(string))
def string_from_tokens(
tokens: list[int], model: str | None = None, encoding_name: str | None = None
) -> str:
"""Return a text string from a list of tokens."""
if model is not None:
encoding = tiktoken.encoding_for_model(model)
elif encoding_name is not None:
encoding = tiktoken.get_encoding(encoding_name)
else:
msg = "Either model or encoding_name must be specified."
raise ValueError(msg)
return encoding.decode(tokens)

View File

@ -15,42 +15,39 @@ logger = logging.getLogger(__name__)
def validate_config_names(parameters: GraphRagConfig) -> None:
"""Validate config file for LLM deployment name typos."""
# Validate Chat LLM configs
# TODO: Replace default_chat_model with a way to select the model
default_llm_settings = parameters.get_language_model_config("default_chat_model")
llm = ModelManager().register_chat(
name="test-llm",
model_type=default_llm_settings.type,
config=default_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
logger.info("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)
# Validate Embeddings LLM configs
embedding_llm_settings = parameters.get_language_model_config(
parameters.embed_text.model_id
)
embed_llm = ModelManager().register_embedding(
name="test-embed-llm",
model_type=embedding_llm_settings.type,
config=embedding_llm_settings,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
logger.info("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa
sys.exit(1)
"""Validate config file for model deployment name typos, by running a quick test message for each."""
for id, config in parameters.models.items():
if config.type in ["chat", "azure_openai", "openai"]:
llm = ModelManager().register_chat(
name="test-llm",
model_type=config.type,
config=config,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:
asyncio.run(
llm.achat("This is an LLM connectivity test. Say Hello World")
)
logger.info("LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"LLM configuration error detected.\n{e}") # noqa
print(f"Failed to validate language model ({id}) params", e) # noqa: T201
sys.exit(1)
elif config.type in ["embedding", "azure_openai_embedding", "openai_embedding"]:
embed_llm = ModelManager().register_embedding(
name="test-embed-llm",
model_type=config.type,
config=config,
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:
asyncio.run(
embed_llm.aembed_batch(["This is an LLM Embedding Test String"])
)
logger.info("Embedding LLM Config Params Validated")
except Exception as e: # noqa: BLE001
logger.error(f"Embedding configuration error detected.\n{e}") # noqa
print(f"Failed to validate embedding model ({id}) params", e) # noqa: T201
sys.exit(1)

View File

@ -17,7 +17,6 @@ from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.logger.progress import Progress
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@ -69,8 +68,6 @@ def create_base_text_units(
zip(*[sort[col] for col in ["id", "text"]], strict=True)
)
callbacks.progress(Progress(percent=0))
agg_dict = {"text_with_ids": list}
if "metadata" in documents:
agg_dict["metadata"] = "first" # type: ignore

View File

@ -13,6 +13,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@ -28,6 +29,7 @@ from graphrag.index.operations.summarize_communities.summarize_communities impor
)
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
@ -102,6 +104,9 @@ async def create_community_reports(
summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
model_config = LanguageModelConfig(**summarization_strategy["llm"])
tokenizer = get_tokenizer(model_config)
max_input_length = summarization_strategy.get(
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
)
@ -110,6 +115,7 @@ async def create_community_reports(
nodes,
edges,
claims,
tokenizer,
callbacks,
max_input_length,
)
@ -122,6 +128,7 @@ async def create_community_reports(
callbacks,
cache,
summarization_strategy,
tokenizer=tokenizer,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,

View File

@ -12,6 +12,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import AsyncType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.index.operations.finalize_community_reports import (
finalize_community_reports,
)
@ -27,6 +28,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
)
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
logger = logging.getLogger(__name__)
@ -88,8 +90,11 @@ async def create_community_reports_text(
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
)
model_config = LanguageModelConfig(**summarization_strategy["llm"])
tokenizer = get_tokenizer(model_config)
local_contexts = build_local_context(
communities, text_units, nodes, max_input_length
communities, text_units, nodes, tokenizer, max_input_length
)
community_reports = await summarize_communities(
@ -100,6 +105,7 @@ async def create_community_reports_text(
callbacks,
cache,
summarization_strategy,
tokenizer=tokenizer,
max_input_length=max_input_length,
async_mode=async_mode,
num_threads=num_threads,

View File

@ -69,7 +69,7 @@ def create_final_documents(
).reset_index(drop=True)
rejoined["id"] = rejoined["id"].astype(str)
rejoined["human_readable_id"] = rejoined.index + 1
rejoined["human_readable_id"] = rejoined.index
if "metadata" not in rejoined.columns:
rejoined["metadata"] = pd.Series(dtype="object")

View File

@ -60,7 +60,7 @@ def create_final_text_units(
) -> pd.DataFrame:
"""All the steps to transform the text units."""
selected = text_units.loc[:, ["id", "text", "document_ids", "n_tokens"]]
selected["human_readable_id"] = selected.index + 1
selected["human_readable_id"] = selected.index
entity_join = _entities(final_entities)
relationship_join = _relationships(final_relationships)

View File

@ -88,6 +88,6 @@ async def extract_covariates(
)
text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
covariates["human_readable_id"] = covariates.index + 1
covariates["human_readable_id"] = covariates.index
return covariates.loc[:, COVARIATES_FINAL_COLUMNS]

View File

@ -15,7 +15,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.extract_graph.extract_graph import (
extract_graph as extractor,
)
from graphrag.index.operations.summarize_descriptions import (
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.index.typing.context import PipelineRunContext

View File

@ -61,6 +61,7 @@ async def extract_graph_nlp(
text_analyzer=text_analyzer,
normalize_edge_weights=extraction_config.normalize_edge_weights,
num_threads=extraction_config.concurrent_requests,
async_mode=extraction_config.async_mode,
cache=cache,
)

View File

@ -21,12 +21,11 @@ from graphrag.config.embeddings import (
)
from graphrag.config.get_embedding_settings import get_embedding_settings
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.embed_text import embed_text
from graphrag.index.operations.embed_text.embed_text import embed_text
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.utils.storage import (
load_table_from_storage,
storage_has_table,
write_table_to_storage,
)
@ -39,27 +38,35 @@ async def run_workflow(
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
logger.info("Workflow started: generate_text_embeddings")
embedded_fields = config.embed_text.names
logger.info("Embedding the following fields: %s", embedded_fields)
documents = None
relationships = None
text_units = None
entities = None
community_reports = None
if await storage_has_table("documents", context.output_storage):
if document_text_embedding in embedded_fields:
documents = await load_table_from_storage("documents", context.output_storage)
if await storage_has_table("relationships", context.output_storage):
if relationship_description_embedding in embedded_fields:
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
if await storage_has_table("text_units", context.output_storage):
if text_unit_text_embedding in embedded_fields:
text_units = await load_table_from_storage("text_units", context.output_storage)
if await storage_has_table("entities", context.output_storage):
if (
entity_title_embedding in embedded_fields
or entity_description_embedding in embedded_fields
):
entities = await load_table_from_storage("entities", context.output_storage)
if await storage_has_table("community_reports", context.output_storage):
if (
community_title_embedding in embedded_fields
or community_summary_embedding in embedded_fields
or community_full_content_embedding in embedded_fields
):
community_reports = await load_table_from_storage(
"community_reports", context.output_storage
)
embedded_fields = config.embed_text.names
text_embed = get_embedding_settings(config)
output = await generate_text_embeddings(

View File

@ -7,13 +7,17 @@ from collections.abc import Callable
from typing import Any, ClassVar
from graphrag.config.enums import ModelType
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
from graphrag.language_model.providers.fnllm.models import (
AzureOpenAIChatFNLLM,
AzureOpenAIEmbeddingFNLLM,
OpenAIChatFNLLM,
OpenAIEmbeddingFNLLM,
)
from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel
from graphrag.language_model.providers.litellm.embedding_model import (
LitellmEmbeddingModel,
)
class ModelFactory:
@ -100,15 +104,20 @@ class ModelFactory:
# --- Register default implementations ---
ModelFactory.register_chat(
ModelType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
ModelType.AzureOpenAIChat.value, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(
ModelType.OpenAIChat, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
)
ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs))
ModelFactory.register_embedding(
ModelType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs)
ModelType.AzureOpenAIEmbedding.value,
lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs),
)
ModelFactory.register_embedding(
ModelType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
)
ModelFactory.register_embedding(
ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
)

View File

@ -16,7 +16,7 @@ from typing_extensions import Self
from graphrag.language_model.factory import ModelFactory
if TYPE_CHECKING:
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
class ModelManager:

View File

@ -2,7 +2,3 @@
# Licensed under the MIT License
"""Base protocol definitions for LLMs."""
from .base import ChatModel, EmbeddingModel
__all__ = ["ChatModel", "EmbeddingModel"]

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""GraphRAG LiteLLM module. Provides LiteLLM-based implementations of chat and embedding models."""

View File

@ -0,0 +1,414 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Chat model implementation using Litellm."""
import inspect
import json
from collections.abc import AsyncGenerator, Generator
from typing import TYPE_CHECKING, Any, cast
import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
CustomStreamWrapper,
ModelResponse, # type: ignore
acompletion,
completion,
)
from pydantic import BaseModel, Field
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
with_retries,
)
from graphrag.language_model.providers.litellm.types import (
AFixedModelCompletion,
FixedModelCompletion,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.response.base import ModelResponse as MR # noqa: N817
litellm.suppress_debug_info = True
def _create_base_completions(
model_config: "LanguageModelConfig",
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
"""Wrap the base litellm completion function with the model configuration.
Args
----
model_config: The configuration for the language model.
Returns
-------
A tuple containing the synchronous and asynchronous completion functions.
"""
model_provider = model_config.model_provider
model = model_config.deployment_name or model_config.model
base_args: dict[str, Any] = {
"drop_params": True, # LiteLLM drop unsupported params for selected model.
"model": f"{model_provider}/{model}",
"timeout": model_config.request_timeout,
"top_p": model_config.top_p,
"n": model_config.n,
"temperature": model_config.temperature,
"frequency_penalty": model_config.frequency_penalty,
"presence_penalty": model_config.presence_penalty,
"api_base": model_config.api_base,
"api_version": model_config.api_version,
"api_key": model_config.api_key,
"organization": model_config.organization,
"proxy": model_config.proxy,
"audience": model_config.audience,
"max_tokens": model_config.max_tokens,
"max_completion_tokens": model_config.max_completion_tokens,
"reasoning_effort": model_config.reasoning_effort,
}
if model_config.auth_type == AuthType.AzureManagedIdentity:
if model_config.model_provider != "azure":
msg = "Azure Managed Identity authentication is only supported for Azure models."
raise ValueError(msg)
base_args["azure_scope"] = base_args.pop("audience")
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
DefaultAzureCredential(),
model_config.audience or COGNITIVE_SERVICES_AUDIENCE,
)
def _base_completion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return completion(**new_args)
async def _base_acompletion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return await acompletion(**new_args)
return (_base_completion, _base_acompletion)
def _create_completions(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache_key_prefix: str,
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
"""Wrap the base litellm completion function with the model configuration and additional features.
Wrap the base litellm completion function with instance variables based on the model configuration.
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
Final function composition order:
- Logging(Cache(Retries(RateLimiter(ModelCompletion()))))
Args
----
model_config: The configuration for the language model.
cache: Optional cache for storing responses.
cache_key_prefix: Prefix for cache keys.
Returns
-------
A tuple containing the synchronous and asynchronous completion functions.
"""
completion, acompletion = _create_base_completions(model_config)
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
# LiteLLM does not support "auto", so we have to check those values here.
# For v3 release, force rpm/tpm to be int and remove the type checks below
# and just check if rate_limit_strategy is enabled.
if model_config.rate_limit_strategy is not None:
rpm = (
model_config.requests_per_minute
if type(model_config.requests_per_minute) is int
else None
)
tpm = (
model_config.tokens_per_minute
if type(model_config.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
completion, acompletion = with_rate_limiter(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
rpm=rpm,
tpm=tpm,
)
if model_config.retry_strategy != "none":
completion, acompletion = with_retries(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
)
if cache is not None:
completion, acompletion = with_cache(
sync_fn=completion,
async_fn=acompletion,
model_config=model_config,
cache=cache,
request_type="chat",
cache_key_prefix=cache_key_prefix,
)
completion, acompletion = with_logging(
sync_fn=completion,
async_fn=acompletion,
)
return (completion, acompletion)
class LitellmModelOutput(BaseModel):
"""A model representing the output from a language model."""
content: str = Field(description="The generated text content")
full_response: None = Field(
default=None, description="The full response from the model, if available"
)
class LitellmModelResponse(BaseModel):
"""A model representing the response from a language model."""
output: LitellmModelOutput = Field(description="The output from the model")
parsed_response: BaseModel | None = Field(
default=None, description="Parsed response from the model"
)
history: list = Field(
default_factory=list,
description="Conversation history including the prompt and response",
)
class LitellmChatModel:
"""LiteLLM-based Chat Model."""
def __init__(
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
**kwargs: Any,
):
self.name = name
self.config = config
self.cache = cache.child(self.name) if cache else None
self.completion, self.acompletion = _create_completions(
config, self.cache, "chat"
)
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
"""Get model arguments supported by litellm."""
args_to_include = [
"name",
"modalities",
"prediction",
"audio",
"logit_bias",
"metadata",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"logprobs",
"top_logprobs",
"parallel_tool_calls",
"web_search_options",
"extra_headers",
"functions",
"function_call",
"thinking",
]
new_args = {k: v for k, v in kwargs.items() if k in args_to_include}
# If using JSON, check if response_format should be a Pydantic model or just a general JSON object
if kwargs.get("json"):
new_args["response_format"] = {"type": "json_object"}
if (
"json_model" in kwargs
and inspect.isclass(kwargs["json_model"])
and issubclass(kwargs["json_model"], BaseModel)
):
new_args["response_format"] = kwargs["json_model"]
return new_args
async def achat(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> "MR":
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
LitellmModelResponse: The generated model response.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = await self.acompletion(messages=messages, stream=False, **new_kwargs) # type: ignore
messages.append({
"role": "assistant",
"content": response.choices[0].message.content or "", # type: ignore
})
parsed_response: BaseModel | None = None
if "response_format" in new_kwargs:
parsed_dict: dict[str, Any] = json.loads(
response.choices[0].message.content or "{}" # type: ignore
)
parsed_response = parsed_dict # type: ignore
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
new_kwargs["response_format"], BaseModel
):
# If response_format is a pydantic model, instantiate it
model_initializer = cast(
"type[BaseModel]", new_kwargs["response_format"]
)
parsed_response = model_initializer(**parsed_dict)
return LitellmModelResponse(
output=LitellmModelOutput(
content=response.choices[0].message.content or "" # type: ignore
),
parsed_response=parsed_response,
history=messages,
)
async def achat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> AsyncGenerator[str, None]:
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
AsyncGenerator[str, None]: The generated response as a stream of strings.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = await self.acompletion(messages=messages, stream=True, **new_kwargs) # type: ignore
async for chunk in response: # type: ignore
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR":
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
LitellmModelResponse: The generated model response.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = self.completion(messages=messages, stream=False, **new_kwargs) # type: ignore
messages.append({
"role": "assistant",
"content": response.choices[0].message.content or "", # type: ignore
})
parsed_response: BaseModel | None = None
if "response_format" in new_kwargs:
parsed_dict: dict[str, Any] = json.loads(
response.choices[0].message.content or "{}" # type: ignore
)
parsed_response = parsed_dict # type: ignore
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
new_kwargs["response_format"], BaseModel
):
# If response_format is a pydantic model, instantiate it
model_initializer = cast(
"type[BaseModel]", new_kwargs["response_format"]
)
parsed_response = model_initializer(**parsed_dict)
return LitellmModelResponse(
output=LitellmModelOutput(
content=response.choices[0].message.content or "" # type: ignore
),
parsed_response=parsed_response,
history=messages,
)
def chat_stream(
self, prompt: str, history: list | None = None, **kwargs: Any
) -> Generator[str, None]:
"""
Generate a response for the given prompt and history.
Args
----
prompt: The prompt to generate a response for.
history: Optional chat history.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
Generator[str, None]: The generated response as a stream of strings.
"""
new_kwargs = self._get_kwargs(**kwargs)
messages: list[dict[str, str]] = history or []
messages.append({"role": "user", "content": prompt})
response = self.completion(messages=messages, stream=True, **new_kwargs) # type: ignore
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content: # type: ignore
yield chunk.choices[0].delta.content # type: ignore

View File

@ -0,0 +1,280 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""Embedding model implementation using Litellm."""
from typing import TYPE_CHECKING, Any
import litellm
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from litellm import (
EmbeddingResponse, # type: ignore
aembedding,
embedding,
)
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
from graphrag.config.enums import AuthType
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
with_cache,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
with_logging,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
with_rate_limiter,
)
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
with_retries,
)
from graphrag.language_model.providers.litellm.types import (
AFixedModelEmbedding,
FixedModelEmbedding,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
litellm.suppress_debug_info = True
def _create_base_embeddings(
model_config: "LanguageModelConfig",
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
"""Wrap the base litellm embedding function with the model configuration.
Args
----
model_config: The configuration for the language model.
Returns
-------
A tuple containing the synchronous and asynchronous embedding functions.
"""
model_provider = model_config.model_provider
model = model_config.deployment_name or model_config.model
base_args: dict[str, Any] = {
"drop_params": True, # LiteLLM drop unsupported params for selected model.
"model": f"{model_provider}/{model}",
"timeout": model_config.request_timeout,
"api_base": model_config.api_base,
"api_version": model_config.api_version,
"api_key": model_config.api_key,
"organization": model_config.organization,
"proxy": model_config.proxy,
"audience": model_config.audience,
}
if model_config.auth_type == AuthType.AzureManagedIdentity:
if model_config.model_provider != "azure":
msg = "Azure Managed Identity authentication is only supported for Azure models."
raise ValueError(msg)
base_args["azure_scope"] = base_args.pop("audience")
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
DefaultAzureCredential(),
model_config.audience or COGNITIVE_SERVICES_AUDIENCE,
)
def _base_embedding(**kwargs: Any) -> EmbeddingResponse:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return embedding(**new_args)
async def _base_aembedding(**kwargs: Any) -> EmbeddingResponse:
new_args = {**base_args, **kwargs}
if "name" in new_args:
new_args.pop("name")
return await aembedding(**new_args)
return (_base_embedding, _base_aembedding)
def _create_embeddings(
model_config: "LanguageModelConfig",
cache: "PipelineCache | None",
cache_key_prefix: str,
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
"""Wrap the base litellm embedding function with the model configuration and additional features.
Wrap the base litellm embedding function with instance variables based on the model configuration.
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
Final function composition order:
- Logging(Cache(Retries(RateLimiter(ModelEmbedding()))))
Args
----
model_config: The configuration for the language model.
cache: Optional cache for storing responses.
cache_key_prefix: Prefix for cache keys.
Returns
-------
A tuple containing the synchronous and asynchronous embedding functions.
"""
embedding, aembedding = _create_base_embeddings(model_config)
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
# LiteLLM does not support "auto", so we have to check those values here.
# For v3 release, force rpm/tpm to be int and remove the type checks below
# and just check if rate_limit_strategy is enabled.
if model_config.rate_limit_strategy is not None:
rpm = (
model_config.requests_per_minute
if type(model_config.requests_per_minute) is int
else None
)
tpm = (
model_config.tokens_per_minute
if type(model_config.tokens_per_minute) is int
else None
)
if rpm is not None or tpm is not None:
embedding, aembedding = with_rate_limiter(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
rpm=rpm,
tpm=tpm,
)
if model_config.retry_strategy != "none":
embedding, aembedding = with_retries(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
)
if cache is not None:
embedding, aembedding = with_cache(
sync_fn=embedding,
async_fn=aembedding,
model_config=model_config,
cache=cache,
request_type="embedding",
cache_key_prefix=cache_key_prefix,
)
embedding, aembedding = with_logging(
sync_fn=embedding,
async_fn=aembedding,
)
return (embedding, aembedding)
class LitellmEmbeddingModel:
"""LiteLLM-based Embedding Model."""
def __init__(
self,
name: str,
config: "LanguageModelConfig",
cache: "PipelineCache | None" = None,
**kwargs: Any,
):
self.name = name
self.config = config
self.cache = cache.child(self.name) if cache else None
self.embedding, self.aembedding = _create_embeddings(
config, self.cache, "embeddings"
)
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
"""Get model arguments supported by litellm."""
args_to_include = [
"name",
"dimensions",
"encoding_format",
"timeout",
"user",
]
return {k: v for k, v in kwargs.items() if k in args_to_include}
async def aembed_batch(
self, text_list: list[str], **kwargs: Any
) -> list[list[float]]:
"""
Batch generate embeddings.
Args
----
text_list: A batch of text inputs to generate embeddings for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A Batch of embeddings.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = await self.aembedding(input=text_list, **new_kwargs)
return [emb.get("embedding", []) for emb in response.data]
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
"""
Async embed.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
An embedding.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = await self.aembedding(input=[text], **new_kwargs)
return (
response.data[0].get("embedding", [])
if response.data and response.data[0]
else []
)
def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Batch generate embeddings.
Args:
text_list: A batch of text inputs to generate embeddings for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
A Batch of embeddings.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = self.embedding(input=text_list, **new_kwargs)
return [emb.get("embedding", []) for emb in response.data]
def embed(self, text: str, **kwargs: Any) -> list[float]:
"""
Embed a single text input.
Args:
text: The text to generate an embedding for.
**kwargs: Additional keyword arguments (e.g., model parameters).
Returns
-------
An embedding.
"""
new_kwargs = self._get_kwargs(**kwargs)
response = self.embedding(input=[text], **new_kwargs)
return (
response.data[0].get("embedding", [])
if response.data and response.data[0]
else []
)

View File

@ -0,0 +1,140 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""
LiteLLM cache key generation.
Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
"""
import hashlib
import inspect
import json
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
_CACHE_VERSION = 3
"""
If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches.
fnllm was on cache version 2 and though we generate
similar cache keys, the objects stored in cache by fnllm and litellm are different.
Using litellm model providers will not be able to reuse caches generated by fnllm
thus we start with version 3 for litellm.
"""
def get_cache_key(
model_config: "LanguageModelConfig",
prefix: str,
messages: str | None = None,
input: str | None = None,
**kwargs: Any,
) -> str:
"""Generate a cache key based on the model configuration and input arguments.
Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
Args
____
model_config: The configuration of the language model.
prefix: A prefix for the cache key.
**kwargs: Additional model input parameters.
Returns
-------
`{prefix}_{data_hash}_v{version}` if prefix is provided.
"""
cache_key: dict[str, Any] = {
"parameters": _get_parameters(model_config, **kwargs),
}
if messages is not None and input is not None:
msg = "Only one of 'messages' or 'input' should be provided."
raise ValueError(msg)
if messages is not None:
cache_key["messages"] = messages
elif input is not None:
cache_key["input"] = input
else:
msg = "Either 'messages' or 'input' must be provided."
raise ValueError(msg)
data_hash = _hash(json.dumps(cache_key, sort_keys=True))
name = kwargs.get("name")
if name:
prefix += f"_{name}"
return f"{prefix}_{data_hash}_v{_CACHE_VERSION}"
def _get_parameters(
model_config: "LanguageModelConfig",
**kwargs: Any,
) -> dict[str, Any]:
"""Pluck out the parameters that define a cache key.
Use the same parameters as fnllm except request timeout.
- embeddings: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/embeddings/parameters.py#L12
- chat: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/chat/parameters.py#L25
Args
____
model_config: The configuration of the language model.
**kwargs: Additional model input parameters.
Returns
-------
dict[str, Any]: A dictionary of parameters that define the cache key.
"""
parameters = {
"model": model_config.deployment_name or model_config.model,
"frequency_penalty": model_config.frequency_penalty,
"max_tokens": model_config.max_tokens,
"max_completion_tokens": model_config.max_completion_tokens,
"n": model_config.n,
"presence_penalty": model_config.presence_penalty,
"temperature": model_config.temperature,
"top_p": model_config.top_p,
"reasoning_effort": model_config.reasoning_effort,
}
keys_to_cache = [
"function_call",
"functions",
"logit_bias",
"logprobs",
"parallel_tool_calls",
"seed",
"service_tier",
"stop",
"tool_choice",
"tools",
"top_logprobs",
"user",
"dimensions",
"encoding_format",
]
parameters.update({key: kwargs.get(key) for key in keys_to_cache if key in kwargs})
response_format = kwargs.get("response_format")
if inspect.isclass(response_format) and issubclass(response_format, BaseModel):
parameters["response_format"] = str(response_format)
elif response_format is not None:
parameters["response_format"] = response_format
return parameters
def _hash(input: str) -> str:
"""Generate a hash for the input string."""
return hashlib.sha256(input.encode()).hexdigest()

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding function wrappers."""

View File

@ -0,0 +1,107 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding cache wrapper."""
import asyncio
from typing import TYPE_CHECKING, Any, Literal
from litellm import EmbeddingResponse, ModelResponse # type: ignore
from graphrag.language_model.providers.litellm.get_cache_key import get_cache_key
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_cache(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
cache: "PipelineCache",
request_type: Literal["chat", "embedding"],
cache_key_prefix: str,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with caching.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
cache: The cache to use for storing responses.
request_type: The type of request being made, either "chat" or "embedding".
cache_key_prefix: The prefix to use for cache keys.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
def _wrapped_with_cache(**kwargs: Any) -> Any:
is_streaming = kwargs.get("stream", False)
if is_streaming:
return sync_fn(**kwargs)
cache_key = get_cache_key(
model_config=model_config, prefix=cache_key_prefix, **kwargs
)
event_loop = asyncio.get_event_loop()
cached_response = event_loop.run_until_complete(cache.get(cache_key))
if (
cached_response is not None
and isinstance(cached_response, dict)
and "response" in cached_response
and cached_response["response"] is not None
and isinstance(cached_response["response"], dict)
):
try:
if request_type == "chat":
return ModelResponse(**cached_response["response"])
return EmbeddingResponse(**cached_response["response"])
except Exception: # noqa: BLE001
# Try to retrieve value from cache but if it fails, continue
# to make the request.
...
response = sync_fn(**kwargs)
event_loop.run_until_complete(
cache.set(cache_key, {"response": response.model_dump()})
)
return response
async def _wrapped_with_cache_async(
**kwargs: Any,
) -> Any:
is_streaming = kwargs.get("stream", False)
if is_streaming:
return await async_fn(**kwargs)
cache_key = get_cache_key(
model_config=model_config, prefix=cache_key_prefix, **kwargs
)
cached_response = await cache.get(cache_key)
if (
cached_response is not None
and isinstance(cached_response, dict)
and "response" in cached_response
and cached_response["response"] is not None
and isinstance(cached_response["response"], dict)
):
try:
if request_type == "chat":
return ModelResponse(**cached_response["response"])
return EmbeddingResponse(**cached_response["response"])
except Exception: # noqa: BLE001
# Try to retrieve value from cache but if it fails, continue
# to make the request.
...
response = await async_fn(**kwargs)
await cache.set(cache_key, {"response": response.model_dump()})
return response
return (_wrapped_with_cache, _wrapped_with_cache_async)

View File

@ -0,0 +1,56 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding logging wrapper."""
import logging
from typing import Any
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
logger = logging.getLogger(__name__)
def with_logging(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with retries.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
def _wrapped_with_logging(**kwargs: Any) -> Any:
try:
return sync_fn(**kwargs)
except Exception as e:
logger.exception(
f"with_logging: Request failed with exception={e}", # noqa: G004, TRY401
)
raise
async def _wrapped_with_logging_async(
**kwargs: Any,
) -> Any:
try:
return await async_fn(**kwargs)
except Exception as e:
logger.exception(
f"with_logging: Async request failed with exception={e}", # noqa: G004, TRY401
)
raise
return (_wrapped_with_logging, _wrapped_with_logging_async)

View File

@ -0,0 +1,97 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding rate limiter wrapper."""
from typing import TYPE_CHECKING, Any
from litellm import token_counter # type: ignore
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
RateLimiterFactory,
)
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_rate_limiter(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
rpm: int | None = None,
tpm: int | None = None,
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with rate limiting.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
processing_event: A threading event that can be used to pause the rate limiter.
rpm: An optional requests per minute limit.
tpm: An optional tokens per minute limit.
If `rpm` and `tpm` is set to 0 or None, rate limiting is disabled.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
rate_limiter_factory = RateLimiterFactory()
if (
model_config.rate_limit_strategy is None
or model_config.rate_limit_strategy not in rate_limiter_factory
):
msg = f"Rate Limiter strategy '{model_config.rate_limit_strategy}' is none or not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
raise ValueError(msg)
rate_limiter_service = rate_limiter_factory.create(
strategy=model_config.rate_limit_strategy, rpm=rpm, tpm=tpm
)
max_tokens = model_config.max_completion_tokens or model_config.max_tokens or 0
def _wrapped_with_rate_limiter(**kwargs: Any) -> Any:
token_count = max_tokens
if "messages" in kwargs:
token_count += token_counter(
model=model_config.model,
messages=kwargs["messages"],
)
elif "input" in kwargs:
token_count += token_counter(
model=model_config.model,
text=kwargs["input"],
)
with rate_limiter_service.acquire(token_count=token_count):
return sync_fn(**kwargs)
async def _wrapped_with_rate_limiter_async(
**kwargs: Any,
) -> Any:
token_count = max_tokens
if "messages" in kwargs:
token_count += token_counter(
model=model_config.model,
messages=kwargs["messages"],
)
elif "input" in kwargs:
token_count += token_counter(
model=model_config.model,
text=kwargs["input"],
)
with rate_limiter_service.acquire(token_count=token_count):
return await async_fn(**kwargs)
return (_wrapped_with_rate_limiter, _wrapped_with_rate_limiter_async)

View File

@ -0,0 +1,54 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM completion/embedding retries wrapper."""
from typing import TYPE_CHECKING, Any
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
RetryFactory,
)
from graphrag.language_model.providers.litellm.types import (
AsyncLitellmRequestFunc,
LitellmRequestFunc,
)
if TYPE_CHECKING:
from graphrag.config.models.language_model_config import LanguageModelConfig
def with_retries(
*,
sync_fn: LitellmRequestFunc,
async_fn: AsyncLitellmRequestFunc,
model_config: "LanguageModelConfig",
) -> tuple[LitellmRequestFunc, AsyncLitellmRequestFunc]:
"""
Wrap the synchronous and asynchronous request functions with retries.
Args
----
sync_fn: The synchronous chat/embedding request function to wrap.
async_fn: The asynchronous chat/embedding request function to wrap.
model_config: The configuration for the language model.
Returns
-------
A tuple containing the wrapped synchronous and asynchronous chat/embedding request functions.
"""
retry_factory = RetryFactory()
retry_service = retry_factory.create(
strategy=model_config.retry_strategy,
max_retries=model_config.max_retries,
max_retry_wait=model_config.max_retry_wait,
)
def _wrapped_with_retries(**kwargs: Any) -> Any:
return retry_service.retry(func=sync_fn, **kwargs)
async def _wrapped_with_retries_async(
**kwargs: Any,
) -> Any:
return await retry_service.aretry(func=async_fn, **kwargs)
return (_wrapped_with_retries, _wrapped_with_retries_async)

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Services."""

View File

@ -0,0 +1,4 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter."""

View File

@ -0,0 +1,37 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter."""
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
class RateLimiter(ABC):
"""Abstract base class for rate limiters."""
@abstractmethod
def __init__(
self,
/,
**kwargs: Any,
) -> None: ...
@abstractmethod
@contextmanager
def acquire(self, *, token_count: int) -> Iterator[None]:
"""
Acquire Rate Limiter.
Args
----
token_count: The estimated number of tokens for the current request.
Yields
------
None: This context manager does not return any value.
"""
msg = "RateLimiter subclasses must implement the acquire method."
raise NotImplementedError(msg)

View File

@ -0,0 +1,22 @@
# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License
"""LiteLLM Rate Limiter Factory."""
from graphrag.config.defaults import DEFAULT_RATE_LIMITER_SERVICES
from graphrag.factory.factory import Factory
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
RateLimiter,
)
class RateLimiterFactory(Factory[RateLimiter]):
"""Singleton factory for creating rate limiter services."""
rate_limiter_factory = RateLimiterFactory()
for service_name, service_cls in DEFAULT_RATE_LIMITER_SERVICES.items():
rate_limiter_factory.register(
strategy=service_name, service_initializer=service_cls
)

Some files were not shown because too many files have changed in this diff Show More