mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdb7e3835b | ||
|
|
ac8a7f5eef | ||
|
|
6c86b0a7bb | ||
|
|
2bd3922d8d | ||
|
|
7f996cf584 | ||
|
|
9bc899fe95 | ||
|
|
2b70e4a4f3 | ||
|
|
82cd3b7df2 | ||
|
|
075cadd59a | ||
|
|
6d7a50b7f0 | ||
|
|
2bf7e7c018 | ||
|
|
6c66b7c30f | ||
|
|
a398cc38bb | ||
|
|
ac95c917d3 | ||
|
|
1cb20b66f5 | ||
|
|
2030f94eb4 | ||
|
|
69ad36e735 | ||
|
|
30bdb35cc8 | ||
|
|
77fb7d9d7d | ||
|
|
469ee8568f | ||
|
|
7c28c70d5c | ||
|
|
5713205210 |
17
.github/workflows/gh-pages.yml
vendored
17
.github/workflows/gh-pages.yml
vendored
@ -6,8 +6,7 @@ permissions:
|
||||
contents: write
|
||||
|
||||
env:
|
||||
POETRY_VERSION: '1.8.3'
|
||||
PYTHON_VERSION: '3.11'
|
||||
PYTHON_VERSION: "3.11"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@ -16,8 +15,6 @@ jobs:
|
||||
GH_PAGES: 1
|
||||
DEBUG: 1
|
||||
GRAPHRAG_API_KEY: ${{ secrets.GRAPHRAG_API_KEY }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@ -29,18 +26,16 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: poetry intsall
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
run: uv sync
|
||||
|
||||
- name: mkdocs build
|
||||
shell: bash
|
||||
run: poetry run poe build_docs
|
||||
run: uv run poe build_docs
|
||||
|
||||
- name: List Docsite Contents
|
||||
run: find site
|
||||
|
||||
24
.github/workflows/python-ci.yml
vendored
24
.github/workflows/python-ci.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -51,7 +48,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -64,30 +61,27 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Check
|
||||
run: |
|
||||
poetry run poe check
|
||||
uv run poe check
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Unit Test
|
||||
run: |
|
||||
poetry run poe test_unit
|
||||
uv run poe test_unit
|
||||
|
||||
- name: Verb Test
|
||||
run: |
|
||||
poetry run poe test_verbs
|
||||
uv run poe test_verbs
|
||||
|
||||
28
.github/workflows/python-integration-tests.yml
vendored
28
.github/workflows/python-integration-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -51,7 +48,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -64,25 +61,24 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Install Azurite
|
||||
id: azuright
|
||||
uses: potatoqualitee/azuright@v1.1
|
||||
- name: Install and start Azurite
|
||||
shell: bash
|
||||
run: |
|
||||
npm install -g azurite
|
||||
azurite --silent --skipApiVersionCheck --location /tmp/azurite --debug /tmp/azurite-debug.log &
|
||||
|
||||
# For more information on installation/setup of Azure Cosmos DB Emulator
|
||||
# https://learn.microsoft.com/en-us/azure/cosmos-db/how-to-develop-emulator?tabs=docker-linux%2Cpython&pivots=api-nosql
|
||||
@ -97,4 +93,4 @@ jobs:
|
||||
|
||||
- name: Integration Test
|
||||
run: |
|
||||
poetry run poe test_integration
|
||||
uv run poe test_integration
|
||||
|
||||
20
.github/workflows/python-notebook-tests.yml
vendored
20
.github/workflows/python-notebook-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -41,8 +38,6 @@ jobs:
|
||||
env:
|
||||
DEBUG: 1
|
||||
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_NOTEBOOK_KEY }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
@ -54,7 +49,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -66,18 +61,15 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Notebook Test
|
||||
run: |
|
||||
poetry run poe test_notebook
|
||||
uv run poe test_notebook
|
||||
|
||||
13
.github/workflows/python-publish.yml
vendored
13
.github/workflows/python-publish.yml
vendored
@ -6,7 +6,6 @@ on:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.8.3"
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
@ -31,21 +30,19 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
run: uv sync
|
||||
|
||||
- name: Export Publication Version
|
||||
run: echo "version=`poetry version --short`" >> $GITHUB_OUTPUT
|
||||
run: echo "version=$(uv version --short)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build Distributable
|
||||
shell: bash
|
||||
run: poetry build
|
||||
run: uv build
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
42
.github/workflows/python-smoke-tests.yml
vendored
42
.github/workflows/python-smoke-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -40,20 +37,8 @@ jobs:
|
||||
fail-fast: false # Continue running all jobs even if one fails
|
||||
env:
|
||||
DEBUG: 1
|
||||
GRAPHRAG_LLM_TYPE: "azure_openai_chat"
|
||||
GRAPHRAG_EMBEDDING_TYPE: "azure_openai_embedding"
|
||||
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
GRAPHRAG_API_BASE: ${{ secrets.GRAPHRAG_API_BASE }}
|
||||
GRAPHRAG_API_VERSION: ${{ secrets.GRAPHRAG_API_VERSION }}
|
||||
GRAPHRAG_LLM_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_LLM_DEPLOYMENT_NAME }}
|
||||
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
# We have Windows + Linux runners in 3.10, so we need to divide the rate limits by 2
|
||||
GRAPHRAG_LLM_TPM: 200_000 # 400_000 / 2
|
||||
GRAPHRAG_LLM_RPM: 1_000 # 2_000 / 2
|
||||
GRAPHRAG_EMBEDDING_TPM: 225_000 # 450_000 / 2
|
||||
GRAPHRAG_EMBEDDING_RPM: 1_000 # 2_000 / 2
|
||||
# Azure AI Search config
|
||||
AZURE_AI_SEARCH_URL_ENDPOINT: ${{ secrets.AZURE_AI_SEARCH_URL_ENDPOINT }}
|
||||
AZURE_AI_SEARCH_API_KEY: ${{ secrets.AZURE_AI_SEARCH_API_KEY }}
|
||||
@ -68,7 +53,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -81,33 +66,32 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Install Azurite
|
||||
id: azuright
|
||||
uses: potatoqualitee/azuright@v1.1
|
||||
- name: Install and start Azurite
|
||||
shell: bash
|
||||
run: |
|
||||
npm install -g azurite
|
||||
azurite --silent --skipApiVersionCheck --location /tmp/azurite --debug /tmp/azurite-debug.log &
|
||||
|
||||
- name: Smoke Test
|
||||
if: steps.changes.outputs.python == 'true'
|
||||
run: |
|
||||
poetry run poe test_smoke
|
||||
uv run poe test_smoke
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
|
||||
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ runner.os }}
|
||||
path: tests/fixtures/*
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,8 @@
|
||||
# Python Artifacts
|
||||
python/*/lib/
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
|
||||
# Test Output
|
||||
.coverage
|
||||
|
||||
14
.semversioner/2.5.0.json
Normal file
14
.semversioner/2.5.0.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Add additional context variable to build index signature for custom parameter bag",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "swap package management from Poetry -> UV",
|
||||
"type": "minor"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-08-14T00:59:46+00:00",
|
||||
"version": "2.5.0"
|
||||
}
|
||||
54
.semversioner/2.6.0.json
Normal file
54
.semversioner/2.6.0.json
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Add LiteLLM chat and embedding model providers.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add LoggerFactory and clean up related API.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add config for NLP async mode.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add optional input documents to indexing API.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "add customization to vector store",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add gpt-5 support by updating fnllm dependency.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Fix all human_readable_id fields to be 0-based.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Fix multi-index search.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Improve upon recent logging refactor",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Make cache, storage, and vector_store factories consistent with similar registration support",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Remove hard-coded community rate limiter.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "generate_text_embeddings only loads tables if embedding field is specified.",
|
||||
"type": "patch"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-09-22T21:44:51+00:00",
|
||||
"version": "2.6.0"
|
||||
}
|
||||
18
.semversioner/2.7.0.json
Normal file
18
.semversioner/2.7.0.json
Normal file
@ -0,0 +1,18 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Set LiteLLM as default in init_content.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Fix Azure auth scope issue with LiteLLM.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Housekeeping toward 2.7.",
|
||||
"type": "patch"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-10-08T22:39:42+00:00",
|
||||
"version": "2.7.0"
|
||||
}
|
||||
57
.vscode/launch.json
vendored
57
.vscode/launch.json
vendored
@ -6,21 +6,24 @@
|
||||
"name": "Indexer",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "index",
|
||||
"--root", "<path_to_ragtest_root_demo>"
|
||||
"index",
|
||||
"--root",
|
||||
"<path_to_index_folder>"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "query",
|
||||
"--root", "<path_to_ragtest_root_demo>",
|
||||
"--method", "global",
|
||||
"query",
|
||||
"--root",
|
||||
"<path_to_index_folder>",
|
||||
"--method", "basic",
|
||||
"--query", "What are the top themes in this story",
|
||||
]
|
||||
},
|
||||
@ -28,12 +31,48 @@
|
||||
"name": "Prompt Tuning",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "uv",
|
||||
"args": [
|
||||
"poe", "prompt-tune",
|
||||
"--config",
|
||||
"<path_to_ragtest_root_demo>/settings.yaml",
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Integration Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/integration/vector_stores",
|
||||
"-k", "test_azure_ai_search"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Verbs Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/verbs",
|
||||
"-k", "test_generate_text_embeddings"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Smoke Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/smoke",
|
||||
"-k", "test_fixtures"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
]
|
||||
}
|
||||
37
.vscode/settings.json
vendored
37
.vscode/settings.json
vendored
@ -1,43 +1,8 @@
|
||||
{
|
||||
"search.exclude": {
|
||||
"**/.yarn": true,
|
||||
"**/.pnp.*": true
|
||||
},
|
||||
"editor.formatOnSave": false,
|
||||
"eslint.nodePath": ".yarn/sdks",
|
||||
"typescript.tsdk": ".yarn/sdks/typescript/lib",
|
||||
"typescript.enablePromptUseWorkspaceTsdk": true,
|
||||
"javascript.preferences.importModuleSpecifier": "relative",
|
||||
"javascript.preferences.importModuleSpecifierEnding": "js",
|
||||
"typescript.preferences.importModuleSpecifier": "relative",
|
||||
"typescript.preferences.importModuleSpecifierEnding": "js",
|
||||
"explorer.fileNesting.enabled": true,
|
||||
"explorer.fileNesting.patterns": {
|
||||
"*.ts": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md",
|
||||
"*.js": "${capture}.js.map, ${capture}.min.js, ${capture}.d.ts",
|
||||
"*.jsx": "${capture}.js",
|
||||
"*.tsx": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md, ${capture}.css",
|
||||
"tsconfig.json": "tsconfig.*.json",
|
||||
"package.json": "package-lock.json, turbo.json, tsconfig.json, rome.json, biome.json, .npmignore, dictionary.txt, cspell.config.yaml",
|
||||
"README.md": "*.md, LICENSE, CODEOWNERS",
|
||||
".eslintrc": ".eslintignore",
|
||||
".prettierrc": ".prettierignore",
|
||||
".gitattributes": ".gitignore",
|
||||
".yarnrc.yml": "yarn.lock, .pnp.*",
|
||||
"jest.config.js": "jest.setup.mjs",
|
||||
"pyproject.toml": "poetry.lock, poetry.toml, mkdocs.yaml",
|
||||
"cspell.config.yaml": "dictionary.txt"
|
||||
},
|
||||
"azureFunctions.postDeployTask": "npm install (functions)",
|
||||
"azureFunctions.projectLanguage": "TypeScript",
|
||||
"azureFunctions.projectRuntime": "~4",
|
||||
"debug.internalConsoleOptions": "neverOpen",
|
||||
"azureFunctions.preDeployTask": "npm prune (functions)",
|
||||
"appService.zipIgnorePattern": [
|
||||
"node_modules{,/**}",
|
||||
".vscode{,/**}"
|
||||
],
|
||||
"python.defaultInterpreterPath": "python/services/.venv/bin/python",
|
||||
"python.defaultInterpreterPath": "${workspaceRoot}/.venv/bin/python",
|
||||
"python.languageServer": "Pylance",
|
||||
"cSpell.customDictionaries": {
|
||||
"project-words": {
|
||||
|
||||
26
CHANGELOG.md
26
CHANGELOG.md
@ -1,6 +1,32 @@
|
||||
# Changelog
|
||||
Note: version releases in the 0.x.y range may introduce breaking changes.
|
||||
|
||||
## 2.7.0
|
||||
|
||||
- minor: Set LiteLLM as default in init_content.
|
||||
- patch: Fix Azure auth scope issue with LiteLLM.
|
||||
- patch: Housekeeping toward 2.7.
|
||||
|
||||
## 2.6.0
|
||||
|
||||
- minor: Add LiteLLM chat and embedding model providers.
|
||||
- minor: Add LoggerFactory and clean up related API.
|
||||
- minor: Add config for NLP async mode.
|
||||
- minor: Add optional input documents to indexing API.
|
||||
- minor: add customization to vector store
|
||||
- patch: Add gpt-5 support by updating fnllm dependency.
|
||||
- patch: Fix all human_readable_id fields to be 0-based.
|
||||
- patch: Fix multi-index search.
|
||||
- patch: Improve upon recent logging refactor
|
||||
- patch: Make cache, storage, and vector_store factories consistent with similar registration support
|
||||
- patch: Remove hard-coded community rate limiter.
|
||||
- patch: generate_text_embeddings only loads tables if embedding field is specified.
|
||||
|
||||
## 2.5.0
|
||||
|
||||
- minor: Add additional context variable to build index signature for custom parameter bag
|
||||
- minor: swap package management from Poetry -> UV
|
||||
|
||||
## 2.4.0
|
||||
|
||||
- minor: Allow injection of custom pipelines.
|
||||
|
||||
@ -22,7 +22,7 @@ or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any addi
|
||||
2. Create a new branch for your contribution: `git checkout -b my-contribution`.
|
||||
3. Make your changes and ensure that the code passes all tests.
|
||||
4. Commit your changes: `git commit -m "Add my contribution"`.
|
||||
5. Create and commit a semver impact document by running `poetry run semversioner add-change -t <major|minor|patch> -d <description>`.
|
||||
5. Create and commit a semver impact document by running `uv run semversioner add-change -t <major|minor|patch> -d <description>`.
|
||||
6. Push your changes to your forked repository: `git push origin my-contribution`.
|
||||
7. Open a pull request to the main repository.
|
||||
|
||||
|
||||
@ -5,29 +5,29 @@
|
||||
| Name | Installation | Purpose |
|
||||
| ------------------- | ------------------------------------------------------------ | ----------------------------------------------------------------------------------- |
|
||||
| Python 3.10 or 3.11 | [Download](https://www.python.org/downloads/) | The library is Python-based. |
|
||||
| Poetry | [Instructions](https://python-poetry.org/docs/#installation) | Poetry is used for package management and virtualenv management in Python codebases |
|
||||
| uv | [Instructions](https://docs.astral.sh/uv/) | uv is used for package management and virtualenv management in Python codebases |
|
||||
|
||||
# Getting Started
|
||||
|
||||
## Install Dependencies
|
||||
```shell
|
||||
# install python dependencies
|
||||
poetry install
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Execute the indexing engine
|
||||
```shell
|
||||
poetry run poe index <...args>
|
||||
uv run poe index <...args>
|
||||
```
|
||||
|
||||
## Execute prompt tuning
|
||||
```shell
|
||||
poetry run poe prompt_tune <...args>
|
||||
uv run poe prompt_tune <...args>
|
||||
```
|
||||
|
||||
## Execute Queries
|
||||
```shell
|
||||
poetry run poe query <...args>
|
||||
uv run poe query <...args>
|
||||
```
|
||||
|
||||
## Repository Structure
|
||||
@ -63,7 +63,7 @@ Where appropriate, the factories expose a registration method for users to provi
|
||||
|
||||
We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PR's include a json file generated by semversioner. When submitting a PR, please run:
|
||||
```shell
|
||||
poetry run semversioner add-change -t patch -d "<a small sentence describing changes made>."
|
||||
uv run semversioner add-change -t patch -d "<a small sentence describing changes made>."
|
||||
```
|
||||
|
||||
# Azurite
|
||||
@ -78,29 +78,29 @@ or by simply running `azurite` in the terminal if already installed globally. Se
|
||||
|
||||
# Lifecycle Scripts
|
||||
|
||||
Our Python package utilizes Poetry to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage custom build scripts.
|
||||
Our Python package utilizes uv to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage custom build scripts.
|
||||
|
||||
Available scripts are:
|
||||
- `poetry run poe index` - Run the Indexing CLI
|
||||
- `poetry run poe query` - Run the Query CLI
|
||||
- `poetry build` - This invokes `poetry build`, which will build a wheel file and other distributable artifacts.
|
||||
- `poetry run poe test` - This will execute all tests.
|
||||
- `poetry run poe test_unit` - This will execute unit tests.
|
||||
- `poetry run poe test_integration` - This will execute integration tests.
|
||||
- `poetry run poe test_smoke` - This will execute smoke tests.
|
||||
- `poetry run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- `uv run poe index` - Run the Indexing CLI
|
||||
- `uv run poe query` - Run the Query CLI
|
||||
- `uv build` - This invokes `uv build`, which will build a wheel file and other distributable artifacts.
|
||||
- `uv run poe test` - This will execute all tests.
|
||||
- `uv run poe test_unit` - This will execute unit tests.
|
||||
- `uv run poe test_integration` - This will execute integration tests.
|
||||
- `uv run poe test_smoke` - This will execute smoke tests.
|
||||
- `uv run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- formatting
|
||||
- documentation formatting
|
||||
- linting
|
||||
- security patterns
|
||||
- type-checking
|
||||
- `poetry run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `poetry run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `poetry run poe format` - Explicitly run the formatter across the package.
|
||||
- `uv run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `uv run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `uv run poe format` - Explicitly run the formatter across the package.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running poetry install
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running uv sync
|
||||
|
||||
Make sure llvm-9 and llvm-9-dev are installed:
|
||||
|
||||
@ -110,13 +110,8 @@ and then in your bashrc, add
|
||||
|
||||
`export LLVM_CONFIG=/usr/bin/llvm-config-9`
|
||||
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running poetry install
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running uv sync
|
||||
|
||||
Make sure you have python3.10-dev installed or more generally `python<version>-dev`
|
||||
|
||||
`sudo apt-get install python3.10-dev`
|
||||
|
||||
### LLM call constantly exceeds TPM, RPM or time limits
|
||||
|
||||
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
|
||||
to reduce concurrency. Please refer to the [Configuration Documents](https://microsoft.github.io/graphrag/config/overview/)
|
||||
|
||||
@ -81,6 +81,7 @@ typer
|
||||
spacy
|
||||
kwargs
|
||||
ollama
|
||||
litellm
|
||||
|
||||
# Library Methods
|
||||
iterrows
|
||||
@ -103,6 +104,8 @@ isin
|
||||
nocache
|
||||
nbconvert
|
||||
levelno
|
||||
acompletion
|
||||
aembedding
|
||||
|
||||
# HTML
|
||||
nbsp
|
||||
|
||||
@ -8,9 +8,38 @@ GraphRAG was built and tested using OpenAI models, so this is the default model
|
||||
|
||||
GraphRAG also utilizes a language model wrapper library used by several projects within our team, called fnllm. fnllm provides two important functions for GraphRAG: rate limiting configuration to help us maximize throughput for large indexing jobs, and robust caching of API calls to minimize consumption on repeated indexes for testing, experimentation, or incremental ingest. fnllm uses the OpenAI Python SDK under the covers, so OpenAI-compliant endpoints are a base requirement out-of-the-box.
|
||||
|
||||
Starting with version 2.6.0, GraphRAG supports using [LiteLLM](https://docs.litellm.ai/) instead of fnllm for calling language models. LiteLLM provides support for 100+ models though it is important to note that when choosing a model it must support returning [structured outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/) adhering to a [JSON schema](https://docs.litellm.ai/docs/completion/json_mode).
|
||||
|
||||
Example using LiteLLm as the language model tool for GraphRAG:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
default_chat_model:
|
||||
type: chat
|
||||
auth_type: api_key
|
||||
api_key: ${GEMINI_API_KEY}
|
||||
model_provider: gemini
|
||||
model: gemini-2.5-flash-lite
|
||||
default_embedding_model:
|
||||
type: embedding
|
||||
auth_type: api_key
|
||||
api_key: ${GEMINI_API_KEY}
|
||||
model_provider: gemini
|
||||
model: gemini-embedding-001
|
||||
```
|
||||
|
||||
To use LiteLLM one must
|
||||
|
||||
- Set `type` to either `chat` or `embedding`.
|
||||
- Provide a `model_provider`, e.g., `openai`, `azure`, `gemini`, etc.
|
||||
- Set the `model` to a one supported by the `model_provider`'s API.
|
||||
- Provide a `deployment_name` if using `azure` as the `model_provider`.
|
||||
|
||||
See [Detailed Configuration](yaml.md) for more details on configuration. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`).
|
||||
|
||||
## Model Selection Considerations
|
||||
|
||||
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo.
|
||||
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo. As stated above, non-OpenAI models are now supported with GraphRAG 2.6.0 and onwards through the use of LiteLLM but the suite of gpt-4 series of models from OpenAI remain the most tested and supported suite of models for GraphRAG.
|
||||
|
||||
Versions of GraphRAG before 2.2.0 made extensive use of `max_tokens` and `logit_bias` to control generated response length or content. The introduction of the o-series of models added new, non-compatible parameters because these models include a reasoning component that has different consumption patterns and response generation attributes than non-reasoning models. GraphRAG 2.2.0 now supports these models, but there are important differences that need to be understood before you switch.
|
||||
|
||||
@ -58,11 +87,11 @@ Another option would be to avoid using a language model at all for the graph ext
|
||||
|
||||
## Using Non-OpenAI Models
|
||||
|
||||
As noted above, our primary experience and focus has been on OpenAI models, so this is what is supported out-of-the-box. Many users have requested support for additional model types, but it's out of the scope of our research to handle the many models available today. There are two approaches you can use to connect to a non-OpenAI model:
|
||||
As shown above, non-OpenAI models may be used via LiteLLM starting with GraphRAG version 2.6.0 but cases may still exist in which some users wish to use models not supported by LiteLLM. There are two approaches one can use to connect to unsupported models:
|
||||
|
||||
### Proxy APIs
|
||||
|
||||
Many users have used platforms such as [ollama](https://ollama.com/) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
|
||||
Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLLM Proxy Server](https://docs.litellm.ai/docs/simple_proxy) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
|
||||
|
||||
### Model Protocol
|
||||
|
||||
|
||||
@ -41,7 +41,8 @@ models:
|
||||
|
||||
- `api_key` **str** - The OpenAI API key to use.
|
||||
- `auth_type` **api_key|azure_managed_identity** - Indicate how you want to authenticate requests.
|
||||
- `type` **openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
|
||||
- `type` **chat**|**embedding**|**openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
|
||||
- `model_provider` **str|None** - The model provider to use, e.g., openai, azure, anthropic, etc. Required when `type == chat|embedding`. When `type == chat|embedding`, [LiteLLM](https://docs.litellm.ai/) is used under the hood which has support for calling 100+ models. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`). [View Language Model Selection](models.md) for more details and examples on using LiteLLM.
|
||||
- `model` **str** - The model name.
|
||||
- `encoding_model` **str** - The text encoding model to use. Default is to use the encoding model aligned with the language model (i.e., it is retrieved from tiktoken if unset).
|
||||
- `api_base` **str** - The API base url to use.
|
||||
|
||||
@ -5,27 +5,27 @@
|
||||
| Name | Installation | Purpose |
|
||||
| ------------------- | ------------------------------------------------------------ | ----------------------------------------------------------------------------------- |
|
||||
| Python 3.10-3.12 | [Download](https://www.python.org/downloads/) | The library is Python-based. |
|
||||
| Poetry | [Instructions](https://python-poetry.org/docs/#installation) | Poetry is used for package management and virtualenv management in Python codebases |
|
||||
| uv | [Instructions](https://docs.astral.sh/uv/) | uv is used for package management and virtualenv management in Python codebases |
|
||||
|
||||
# Getting Started
|
||||
|
||||
## Install Dependencies
|
||||
|
||||
```sh
|
||||
# Install Python dependencies.
|
||||
poetry install
|
||||
# install python dependencies
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Execute the Indexing Engine
|
||||
|
||||
```sh
|
||||
poetry run poe index <...args>
|
||||
uv run poe index <...args>
|
||||
```
|
||||
|
||||
## Executing Queries
|
||||
|
||||
```sh
|
||||
poetry run poe query <...args>
|
||||
uv run poe query <...args>
|
||||
```
|
||||
|
||||
# Azurite
|
||||
@ -40,31 +40,31 @@ or by simply running `azurite` in the terminal if already installed globally. Se
|
||||
|
||||
# Lifecycle Scripts
|
||||
|
||||
Our Python package utilizes Poetry to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage build scripts.
|
||||
Our Python package utilize uv to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage build scripts.
|
||||
|
||||
Available scripts are:
|
||||
|
||||
- `poetry run poe index` - Run the Indexing CLI
|
||||
- `poetry run poe query` - Run the Query CLI
|
||||
- `poetry build` - This invokes `poetry build`, which will build a wheel file and other distributable artifacts.
|
||||
- `poetry run poe test` - This will execute all tests.
|
||||
- `poetry run poe test_unit` - This will execute unit tests.
|
||||
- `poetry run poe test_integration` - This will execute integration tests.
|
||||
- `poetry run poe test_smoke` - This will execute smoke tests.
|
||||
- `poetry run poe test_verbs` - This will execute tests of the basic workflows.
|
||||
- `poetry run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- `uv run poe index` - Run the Indexing CLI
|
||||
- `uv run poe query` - Run the Query CLI
|
||||
- `uv build` - This will build a wheel file and other distributable artifacts.
|
||||
- `uv run poe test` - This will execute all tests.
|
||||
- `uv run poe test_unit` - This will execute unit tests.
|
||||
- `uv run poe test_integration` - This will execute integration tests.
|
||||
- `uv run poe test_smoke` - This will execute smoke tests.
|
||||
- `uv run poe test_verbs` - This will execute tests of the basic workflows.
|
||||
- `uv run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- formatting
|
||||
- documentation formatting
|
||||
- linting
|
||||
- security patterns
|
||||
- type-checking
|
||||
- `poetry run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `poetry run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `poetry run poe format` - Explicitly run the formatter across the package.
|
||||
- `uv run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `uv run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `uv run poe format` - Explicitly run the formatter across the package.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running poetry install
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running uv install
|
||||
|
||||
Make sure llvm-9 and llvm-9-dev are installed:
|
||||
|
||||
@ -73,14 +73,3 @@ Make sure llvm-9 and llvm-9-dev are installed:
|
||||
and then in your bashrc, add
|
||||
|
||||
`export LLVM_CONFIG=/usr/bin/llvm-config-9`
|
||||
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running poetry install
|
||||
|
||||
Make sure you have python3.10-dev installed or more generally `python<version>-dev`
|
||||
|
||||
`sudo apt-get install python3.10-dev`
|
||||
|
||||
### LLM call constantly exceeds TPM, RPM or time limits
|
||||
|
||||
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
|
||||
to reduce concurrency. Please refer to the [Configuration Documents](config/overview.md)
|
||||
|
||||
@ -67,6 +67,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# note that we expect this to fail on the deployed docs because the PROJECT_DIRECTORY is not set to a real location.\n",
|
||||
"# if you run this notebook locally, make sure to point at a location containing your settings.yaml\n",
|
||||
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
|
||||
]
|
||||
},
|
||||
|
||||
680
docs/examples_notebooks/custom_vector_store.ipynb
Normal file
680
docs/examples_notebooks/custom_vector_store.ipynb
Normal file
@ -0,0 +1,680 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) 2024 Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Bring-Your-Own Vector Store\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to implement a custom vector store and register for usage with GraphRAG.\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"\n",
|
||||
"GraphRAG uses a plug-and-play architecture that allow for easy integration of custom vector stores (outside of what is natively supported) by following a factory design pattern. This allows you to:\n",
|
||||
"\n",
|
||||
"- **Extend functionality**: Add support for new vector database backends\n",
|
||||
"- **Customize behavior**: Implement specialized search logic or data structures\n",
|
||||
"- **Integrate existing systems**: Connect GraphRAG to your existing vector database infrastructure\n",
|
||||
"\n",
|
||||
"### What You'll Learn\n",
|
||||
"\n",
|
||||
"1. Understanding the `BaseVectorStore` interface\n",
|
||||
"2. Implementing a custom vector store class\n",
|
||||
"3. Registering your vector store with the `VectorStoreFactory`\n",
|
||||
"4. Testing and validating your implementation\n",
|
||||
"5. Configuring GraphRAG to use your custom vector store\n",
|
||||
"\n",
|
||||
"Let's get started!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 1: Import Required Dependencies\n",
|
||||
"\n",
|
||||
"First, let's import the necessary GraphRAG components and other dependencies we'll need.\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install graphrag\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.data_model.types import TextEmbedder\n",
|
||||
"\n",
|
||||
"# GraphRAG vector store components\n",
|
||||
"from graphrag.vector_stores.base import (\n",
|
||||
" BaseVectorStore,\n",
|
||||
" VectorStoreDocument,\n",
|
||||
" VectorStoreSearchResult,\n",
|
||||
")\n",
|
||||
"from graphrag.vector_stores.factory import VectorStoreFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 2: Understand the BaseVectorStore Interface\n",
|
||||
"\n",
|
||||
"Before using a custom vector store, let's examine the `BaseVectorStore` interface to understand what methods need to be implemented."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's inspect the BaseVectorStore class to understand the required methods\n",
|
||||
"import inspect\n",
|
||||
"\n",
|
||||
"print(\"BaseVectorStore Abstract Methods:\")\n",
|
||||
"print(\"=\" * 40)\n",
|
||||
"\n",
|
||||
"abstract_methods = []\n",
|
||||
"for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n",
|
||||
" if getattr(method, \"__isabstractmethod__\", False):\n",
|
||||
" signature = inspect.signature(method)\n",
|
||||
" abstract_methods.append(f\"• {name}{signature}\")\n",
|
||||
" print(f\"• {name}{signature}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 3: Implement a Custom Vector Store\n",
|
||||
"\n",
|
||||
"Now let's implement a simple in-memory vector store as an example. This vector store will:\n",
|
||||
"\n",
|
||||
"- Store documents and vectors in memory using Python data structures\n",
|
||||
"- Support all required BaseVectorStore methods\n",
|
||||
"\n",
|
||||
"**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SimpleInMemoryVectorStore(BaseVectorStore):\n",
|
||||
" \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n",
|
||||
"\n",
|
||||
" This vector store stores documents and their embeddings in memory and provides\n",
|
||||
" basic similarity search functionality using cosine similarity.\n",
|
||||
"\n",
|
||||
" WARNING: This is for demonstration only - not suitable for production use.\n",
|
||||
" For production, consider using optimized vector databases like LanceDB,\n",
|
||||
" Azure AI Search, or other specialized vector stores.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Internal storage for documents and vectors\n",
|
||||
" documents: dict[str, VectorStoreDocument]\n",
|
||||
" vectors: dict[str, np.ndarray]\n",
|
||||
" connected: bool\n",
|
||||
"\n",
|
||||
" def __init__(self, **kwargs: Any):\n",
|
||||
" \"\"\"Initialize the in-memory vector store.\"\"\"\n",
|
||||
" super().__init__(**kwargs)\n",
|
||||
"\n",
|
||||
" self.documents: dict[str, VectorStoreDocument] = {}\n",
|
||||
" self.vectors: dict[str, np.ndarray] = {}\n",
|
||||
" self.connected = False\n",
|
||||
"\n",
|
||||
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def connect(self, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
|
||||
" self.connected = True\n",
|
||||
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def load_documents(\n",
|
||||
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Load documents into the vector store.\"\"\"\n",
|
||||
" if not self.connected:\n",
|
||||
" msg = \"Vector store not connected. Call connect() first.\"\n",
|
||||
" raise RuntimeError(msg)\n",
|
||||
"\n",
|
||||
" if overwrite:\n",
|
||||
" self.documents.clear()\n",
|
||||
" self.vectors.clear()\n",
|
||||
"\n",
|
||||
" loaded_count = 0\n",
|
||||
" for doc in documents:\n",
|
||||
" if doc.vector is not None:\n",
|
||||
" doc_id = str(doc.id)\n",
|
||||
" self.documents[doc_id] = doc\n",
|
||||
" self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n",
|
||||
" loaded_count += 1\n",
|
||||
"\n",
|
||||
" print(f\"📚 Loaded {loaded_count} documents into vector store\")\n",
|
||||
"\n",
|
||||
" def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
|
||||
" \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n",
|
||||
" # Normalize vectors\n",
|
||||
" norm1 = np.linalg.norm(vec1)\n",
|
||||
" norm2 = np.linalg.norm(vec2)\n",
|
||||
"\n",
|
||||
" if norm1 == 0 or norm2 == 0:\n",
|
||||
" return 0.0\n",
|
||||
"\n",
|
||||
" return float(np.dot(vec1, vec2) / (norm1 * norm2))\n",
|
||||
"\n",
|
||||
" def similarity_search_by_vector(\n",
|
||||
" self, query_embedding: list[float], k: int = 10, **kwargs: Any\n",
|
||||
" ) -> list[VectorStoreSearchResult]:\n",
|
||||
" \"\"\"Perform similarity search using a query vector.\"\"\"\n",
|
||||
" if not self.connected:\n",
|
||||
" msg = \"Vector store not connected. Call connect() first.\"\n",
|
||||
" raise RuntimeError(msg)\n",
|
||||
"\n",
|
||||
" if not self.vectors:\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" query_vec = np.array(query_embedding, dtype=np.float32)\n",
|
||||
" similarities = []\n",
|
||||
"\n",
|
||||
" # Calculate similarity with all stored vectors\n",
|
||||
" for doc_id, stored_vec in self.vectors.items():\n",
|
||||
" similarity = self._cosine_similarity(query_vec, stored_vec)\n",
|
||||
" similarities.append((doc_id, similarity))\n",
|
||||
"\n",
|
||||
" # Sort by similarity (descending) and take top k\n",
|
||||
" similarities.sort(key=lambda x: x[1], reverse=True)\n",
|
||||
" top_k = similarities[:k]\n",
|
||||
"\n",
|
||||
" # Create search results\n",
|
||||
" results = []\n",
|
||||
" for doc_id, score in top_k:\n",
|
||||
" document = self.documents[doc_id]\n",
|
||||
" result = VectorStoreSearchResult(document=document, score=score)\n",
|
||||
" results.append(result)\n",
|
||||
"\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
" def similarity_search_by_text(\n",
|
||||
" self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n",
|
||||
" ) -> list[VectorStoreSearchResult]:\n",
|
||||
" \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n",
|
||||
" # Embed the text first\n",
|
||||
" query_embedding = text_embedder(text)\n",
|
||||
"\n",
|
||||
" # Use vector search with the embedding\n",
|
||||
" return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
|
||||
"\n",
|
||||
" def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
|
||||
" \"\"\"Build a query filter to filter documents by id.\n",
|
||||
"\n",
|
||||
" For this simple implementation, we return the list of IDs as the filter.\n",
|
||||
" \"\"\"\n",
|
||||
" return [str(id_) for id_ in include_ids]\n",
|
||||
"\n",
|
||||
" def search_by_id(self, id: str) -> VectorStoreDocument:\n",
|
||||
" \"\"\"Search for a document by id.\"\"\"\n",
|
||||
" doc_id = str(id)\n",
|
||||
" if doc_id not in self.documents:\n",
|
||||
" msg = f\"Document with id '{id}' not found\"\n",
|
||||
" raise KeyError(msg)\n",
|
||||
"\n",
|
||||
" return self.documents[doc_id]\n",
|
||||
"\n",
|
||||
" def get_stats(self) -> dict[str, Any]:\n",
|
||||
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
|
||||
" return {\n",
|
||||
" \"index_name\": self.index_name,\n",
|
||||
" \"document_count\": len(self.documents),\n",
|
||||
" \"vector_count\": len(self.vectors),\n",
|
||||
" \"connected\": self.connected,\n",
|
||||
" \"vector_dimension\": len(next(iter(self.vectors.values())))\n",
|
||||
" if self.vectors\n",
|
||||
" else 0,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"✅ SimpleInMemoryVectorStore class defined!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 4: Register the Custom Vector Store\n",
|
||||
"\n",
|
||||
"Now let's register our custom vector store with the `VectorStoreFactory` so it can be used throughout GraphRAG."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Register our custom vector store with a unique identifier\n",
|
||||
"CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
|
||||
"\n",
|
||||
"# Register the vector store class\n",
|
||||
"VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
|
||||
"\n",
|
||||
"print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
|
||||
"\n",
|
||||
"# Verify registration\n",
|
||||
"available_types = VectorStoreFactory.get_vector_store_types()\n",
|
||||
"print(f\"\\n📋 Available vector store types: {available_types}\")\n",
|
||||
"print(\n",
|
||||
" f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 5: Test the Custom Vector Store\n",
|
||||
"\n",
|
||||
"Let's create some sample data and test our custom vector store implementation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create sample documents with mock embeddings\n",
|
||||
"def create_mock_embedding(dimension: int = 384) -> list[float]:\n",
|
||||
" \"\"\"Create a random embedding vector for testing.\"\"\"\n",
|
||||
" return np.random.normal(0, 1, dimension).tolist()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Sample documents\n",
|
||||
"sample_documents = [\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_1\",\n",
|
||||
" text=\"GraphRAG is a powerful knowledge graph extraction and reasoning framework.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"technology\", \"source\": \"documentation\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_2\",\n",
|
||||
" text=\"Vector stores enable efficient similarity search over high-dimensional data.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"technology\", \"source\": \"research\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_3\",\n",
|
||||
" text=\"Machine learning models can process and understand natural language text.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"AI\", \"source\": \"article\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_4\",\n",
|
||||
" text=\"Custom implementations allow for specialized behavior and integration.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"development\", \"source\": \"tutorial\"},\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(f\"📝 Created {len(sample_documents)} sample documents\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test creating vector store using the factory\n",
|
||||
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
|
||||
"\n",
|
||||
"# Create vector store instance using factory\n",
|
||||
"vector_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
|
||||
"print(f\"📊 Initial stats: {vector_store.get_stats()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Connect and load documents\n",
|
||||
"vector_store.connect()\n",
|
||||
"vector_store.load_documents(sample_documents)\n",
|
||||
"\n",
|
||||
"print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test similarity search\n",
|
||||
"query_vector = create_mock_embedding() # Random query vector for testing\n",
|
||||
"\n",
|
||||
"search_results = vector_store.similarity_search_by_vector(\n",
|
||||
" query_vector,\n",
|
||||
" k=3, # Get top 3 similar documents\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"🔍 Found {len(search_results)} similar documents:\\n\")\n",
|
||||
"\n",
|
||||
"for i, result in enumerate(search_results, 1):\n",
|
||||
" doc = result.document\n",
|
||||
" print(f\"{i}. ID: {doc.id}\")\n",
|
||||
" print(f\" Text: {doc.text[:60]}...\")\n",
|
||||
" print(f\" Similarity Score: {result.score:.4f}\")\n",
|
||||
" print(f\" Category: {doc.attributes.get('category', 'N/A')}\")\n",
|
||||
" print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test search by ID\n",
|
||||
"try:\n",
|
||||
" found_doc = vector_store.search_by_id(\"doc_2\")\n",
|
||||
" print(\"✅ Found document by ID:\")\n",
|
||||
" print(f\" ID: {found_doc.id}\")\n",
|
||||
" print(f\" Text: {found_doc.text}\")\n",
|
||||
" print(f\" Attributes: {found_doc.attributes}\")\n",
|
||||
"except KeyError as e:\n",
|
||||
" print(f\"❌ Error: {e}\")\n",
|
||||
"\n",
|
||||
"# Test filter by ID\n",
|
||||
"id_filter = vector_store.filter_by_id([\"doc_1\", \"doc_3\"])\n",
|
||||
"print(f\"\\n🔧 ID filter result: {id_filter}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 6: Configuration for GraphRAG\n",
|
||||
"\n",
|
||||
"Now let's see how you would configure GraphRAG to use your custom vector store in a settings file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Example GraphRAG yaml settings\n",
|
||||
"example_settings = {\n",
|
||||
" \"vector_store\": {\n",
|
||||
" \"default_vector_store\": {\n",
|
||||
" \"type\": CUSTOM_VECTOR_STORE_TYPE, # \"simple_memory\"\n",
|
||||
" \"collection_name\": \"graphrag_entities\",\n",
|
||||
" # Add any custom parameters your vector store needs\n",
|
||||
" \"custom_parameter\": \"custom_value\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" # Other GraphRAG configuration...\n",
|
||||
" \"models\": {\n",
|
||||
" \"default_embedding_model\": {\n",
|
||||
" \"type\": \"openai_embedding\",\n",
|
||||
" \"model\": \"text-embedding-3-small\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Convert to YAML format for settings.yml\n",
|
||||
"yaml_config = yaml.dump(example_settings, default_flow_style=False, indent=2)\n",
|
||||
"\n",
|
||||
"print(\"📄 Example settings.yml configuration:\")\n",
|
||||
"print(\"=\" * 40)\n",
|
||||
"print(yaml_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 7: Integration with GraphRAG Pipeline\n",
|
||||
"\n",
|
||||
"Here's how your custom vector store would be used in a typical GraphRAG pipeline."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Example of how GraphRAG would use your custom vector store\n",
|
||||
"def simulate_graphrag_pipeline():\n",
|
||||
" \"\"\"Simulate how GraphRAG would use the custom vector store.\"\"\"\n",
|
||||
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
|
||||
"\n",
|
||||
" # 1. GraphRAG creates vector store using factory\n",
|
||||
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
|
||||
"\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=schema,\n",
|
||||
" similarity_threshold=0.3,\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
" print(\"✅ Step 1: Vector store created and connected\")\n",
|
||||
"\n",
|
||||
" # 2. During indexing, GraphRAG loads extracted entities\n",
|
||||
" entity_documents = [\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=f\"entity_{i}\",\n",
|
||||
" text=f\"Entity {i} description: Important concept in the knowledge graph\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"type\": \"entity\", \"importance\": i % 3 + 1},\n",
|
||||
" )\n",
|
||||
" for i in range(10)\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" store.load_documents(entity_documents)\n",
|
||||
" print(f\"✅ Step 2: Loaded {len(entity_documents)} entity documents\")\n",
|
||||
"\n",
|
||||
" # 3. During query time, GraphRAG searches for relevant entities\n",
|
||||
" query_embedding = create_mock_embedding()\n",
|
||||
" relevant_entities = store.similarity_search_by_vector(query_embedding, k=5)\n",
|
||||
"\n",
|
||||
" print(f\"✅ Step 3: Found {len(relevant_entities)} relevant entities for query\")\n",
|
||||
"\n",
|
||||
" # 4. GraphRAG uses these entities for context building\n",
|
||||
" context_entities = [result.document for result in relevant_entities]\n",
|
||||
"\n",
|
||||
" print(\"✅ Step 4: Context built using retrieved entities\")\n",
|
||||
" print(f\"📊 Final stats: {store.get_stats()}\")\n",
|
||||
"\n",
|
||||
" return context_entities\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the simulation\n",
|
||||
"context = simulate_graphrag_pipeline()\n",
|
||||
"print(f\"\\n🎯 Retrieved {len(context)} entities for context building\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 8: Testing and Validation\n",
|
||||
"\n",
|
||||
"Let's create a comprehensive test suite to ensure our vector store works correctly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_custom_vector_store():\n",
|
||||
" \"\"\"Comprehensive test suite for the custom vector store.\"\"\"\n",
|
||||
" print(\"🧪 Running comprehensive vector store tests...\\n\")\n",
|
||||
"\n",
|
||||
" # Test 1: Basic functionality\n",
|
||||
" print(\"Test 1: Basic functionality\")\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
" # Load test documents\n",
|
||||
" test_docs = sample_documents[:2]\n",
|
||||
" store.load_documents(test_docs)\n",
|
||||
"\n",
|
||||
" assert len(store.documents) == 2, \"Should have 2 documents\"\n",
|
||||
" assert len(store.vectors) == 2, \"Should have 2 vectors\"\n",
|
||||
" print(\"✅ Basic functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 2: Search functionality\n",
|
||||
" print(\"\\nTest 2: Search functionality\")\n",
|
||||
" query_vec = create_mock_embedding()\n",
|
||||
" results = store.similarity_search_by_vector(query_vec, k=5)\n",
|
||||
"\n",
|
||||
" assert len(results) <= 2, \"Should not return more results than documents\"\n",
|
||||
" assert all(isinstance(r, VectorStoreSearchResult) for r in results), (\n",
|
||||
" \"Should return VectorStoreSearchResult objects\"\n",
|
||||
" )\n",
|
||||
" assert all(-1 <= r.score <= 1 for r in results), (\n",
|
||||
" \"Similarity scores should be between -1 and 1\"\n",
|
||||
" )\n",
|
||||
" print(\"✅ Search functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 3: Search by ID\n",
|
||||
" print(\"\\nTest 3: Search by ID\")\n",
|
||||
" found_doc = store.search_by_id(\"doc_1\")\n",
|
||||
" assert found_doc.id == \"doc_1\", \"Should find correct document\"\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" store.search_by_id(\"nonexistent\")\n",
|
||||
" assert False, \"Should raise KeyError for nonexistent ID\"\n",
|
||||
" except KeyError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" print(\"✅ Search by ID test passed\")\n",
|
||||
"\n",
|
||||
" # Test 4: Filter functionality\n",
|
||||
" print(\"\\nTest 4: Filter functionality\")\n",
|
||||
" filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
|
||||
" assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
|
||||
" print(\"✅ Filter functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 5: Error handling\n",
|
||||
" print(\"\\nTest 5: Error handling\")\n",
|
||||
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" disconnected_store.load_documents(test_docs)\n",
|
||||
" assert False, \"Should raise error when not connected\"\n",
|
||||
" except RuntimeError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" disconnected_store.similarity_search_by_vector(query_vec)\n",
|
||||
" assert False, \"Should raise error when not connected\"\n",
|
||||
" except RuntimeError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" print(\"✅ Error handling test passed\")\n",
|
||||
"\n",
|
||||
" print(\"\\n🎉 All tests passed! Your custom vector store is working correctly.\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the tests\n",
|
||||
"test_custom_vector_store()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary and Next Steps\n",
|
||||
"\n",
|
||||
"Congratulations! You've successfully learned how to implement and register a custom vector store with GraphRAG. Here's what you accomplished:\n",
|
||||
"\n",
|
||||
"### What You Built\n",
|
||||
"- ✅ **Custom Vector Store Class**: Implemented `SimpleInMemoryVectorStore` with all required methods\n",
|
||||
"- ✅ **Factory Integration**: Registered your vector store with `VectorStoreFactory`\n",
|
||||
"- ✅ **Comprehensive Testing**: Validated functionality with a full test suite\n",
|
||||
"- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n",
|
||||
"\n",
|
||||
"### Key Takeaways\n",
|
||||
"1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n",
|
||||
"2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n",
|
||||
"3. **Configuration**: Vector stores are configured in GraphRAG settings files\n",
|
||||
"4. **Testing**: Thoroughly test all functionality before deploying\n",
|
||||
"\n",
|
||||
"### Next Steps\n",
|
||||
"Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n",
|
||||
"\n",
|
||||
"### Resources\n",
|
||||
"- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n",
|
||||
"\n",
|
||||
"Happy building! 🚀"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@ -20,11 +20,11 @@
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_entities,\n",
|
||||
@ -37,6 +37,7 @@
|
||||
" DRIFTSearchContextBuilder,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
|
||||
"\n",
|
||||
"INPUT_DIR = \"./inputs/operation dulce\"\n",
|
||||
@ -62,12 +63,16 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
"full_content_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-community-full_content\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-community-full_content\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -94,33 +99,33 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"tokenizer = get_tokenizer(chat_config)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
]
|
||||
@ -173,12 +178,12 @@
|
||||
" reports=reports,\n",
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" text_units=text_units,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" config=drift_params,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"search = DRIFTSearch(\n",
|
||||
" model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
|
||||
" model=chat_model, context_builder=context_builder, tokenizer=tokenizer\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -212,7 +217,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -226,7 +231,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -32,7 +31,8 @@
|
||||
"from graphrag.query.structured_search.global_search.community_context import (\n",
|
||||
" GlobalCommunityContext,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -58,21 +58,21 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -142,7 +142,7 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -193,7 +193,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -241,7 +241,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -255,7 +255,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -57,22 +56,24 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -155,11 +156,11 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" dynamic_community_selection=True,\n",
|
||||
" dynamic_community_selection_kwargs={\n",
|
||||
" \"model\": model,\n",
|
||||
" \"token_encoder\": token_encoder,\n",
|
||||
" \"tokenizer\": tokenizer,\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
@ -211,7 +212,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -255,7 +256,7 @@
|
||||
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
|
||||
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
|
||||
"print(\n",
|
||||
" f\"Build context ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Build context LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")\n",
|
||||
"# inspect number of LLM calls and tokens in map-reduce\n",
|
||||
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
|
||||
@ -266,14 +267,14 @@
|
||||
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"Map-reduce ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Map-reduce LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -287,7 +288,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -202,10 +202,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
|
||||
"\n",
|
||||
"from graphrag.cache.factory import CacheFactory\n",
|
||||
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
|
||||
"from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
|
||||
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
|
||||
"\n",
|
||||
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
|
||||
"# We'll construct the context and run this function flow directly to avoid everything else\n",
|
||||
|
||||
194
docs/examples_notebooks/input_documents.ipynb
Normal file
194
docs/examples_notebooks/input_documents.ipynb
Normal file
@ -0,0 +1,194 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) 2024 Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example of indexing from an existing in-memory dataframe\n",
|
||||
"\n",
|
||||
"Newer versions of GraphRAG let you submit a dataframe directly instead of running through the input processing step. This notebook demonstrates with regular or update runs.\n",
|
||||
"\n",
|
||||
"If performing an update, the assumption is that your dataframe contains only the new documents to add to the index."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"import graphrag.api as api\n",
|
||||
"from graphrag.config.load_config import load_config\n",
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PROJECT_DIRECTORY = \"<your project directory>\"\n",
|
||||
"UPDATE = False\n",
|
||||
"FILENAME = \"new_documents.parquet\" if UPDATE else \"<original_documents>.parquet\"\n",
|
||||
"inputs = pd.read_parquet(f\"{PROJECT_DIRECTORY}/input/{FILENAME}\")\n",
|
||||
"# Only the bare minimum for input. These are the same fields that would be present after the load_input_documents workflow\n",
|
||||
"inputs = inputs.loc[:, [\"id\", \"title\", \"text\", \"creation_date\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generate a `GraphRagConfig` object"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Indexing API\n",
|
||||
"\n",
|
||||
"*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build an index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index_result: list[PipelineRunResult] = await api.build_index(\n",
|
||||
" config=graphrag_config, input_documents=inputs, is_update_run=UPDATE\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# index_result is a list of workflows that make up the indexing pipeline that was run\n",
|
||||
"for workflow_result in index_result:\n",
|
||||
" status = f\"error\\n{workflow_result.errors}\" if workflow_result.errors else \"success\"\n",
|
||||
" print(f\"Workflow Name: {workflow_result.workflow}\\tStatus: {status}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Query an index\n",
|
||||
"\n",
|
||||
"To query an index, several index files must first be read into memory and passed to the query API. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/entities.parquet\")\n",
|
||||
"communities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/communities.parquet\")\n",
|
||||
"community_reports = pd.read_parquet(\n",
|
||||
" f\"{PROJECT_DIRECTORY}/output/community_reports.parquet\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response, context = await api.global_search(\n",
|
||||
" config=graphrag_config,\n",
|
||||
" entities=entities,\n",
|
||||
" communities=communities,\n",
|
||||
" community_reports=community_reports,\n",
|
||||
" community_level=2,\n",
|
||||
" dynamic_community_selection=False,\n",
|
||||
" response_type=\"Multiple Paragraphs\",\n",
|
||||
" query=\"What are the top five themes of the dataset?\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pprint(context) # noqa: T203"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -19,8 +19,8 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_covariates,\n",
|
||||
@ -102,7 +102,9 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -195,37 +197,38 @@
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
")\n",
|
||||
"\n",
|
||||
"tokenizer = get_tokenizer(chat_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -251,7 +254,7 @@
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE\n",
|
||||
" text_embedder=text_embedder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -314,7 +317,7 @@
|
||||
"search_engine = LocalSearch(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
|
||||
@ -426,7 +429,7 @@
|
||||
"question_generator = LocalQuestionGen(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
")"
|
||||
@ -451,7 +454,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -465,7 +468,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
# Getting Started
|
||||
|
||||
⚠️ GraphRAG can consume a lot of LLM resources! We strongly recommend starting with the tutorial dataset here until you understand how the system works, and consider experimenting with fast/inexpensive models first before committing to a big indexing job.
|
||||
|
||||
## Requirements
|
||||
|
||||
[Python 3.10-3.12](https://www.python.org/downloads/)
|
||||
@ -24,25 +26,25 @@ pip install graphrag
|
||||
We need to set up a data project and some initial configuration. First let's get a sample dataset ready:
|
||||
|
||||
```sh
|
||||
mkdir -p ./ragtest/input
|
||||
mkdir -p ./christmas/input
|
||||
```
|
||||
|
||||
Get a copy of A Christmas Carol by Charles Dickens from a trusted source:
|
||||
|
||||
```sh
|
||||
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./ragtest/input/book.txt
|
||||
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./christmas/input/book.txt
|
||||
```
|
||||
|
||||
## Set Up Your Workspace Variables
|
||||
|
||||
To initialize your workspace, first run the `graphrag init` command.
|
||||
Since we have already configured a directory named `./ragtest` in the previous step, run the following command:
|
||||
Since we have already configured a directory named `./christmas` in the previous step, run the following command:
|
||||
|
||||
```sh
|
||||
graphrag init --root ./ragtest
|
||||
graphrag init --root ./christmas
|
||||
```
|
||||
|
||||
This will create two files: `.env` and `settings.yaml` in the `./ragtest` directory.
|
||||
This will create two files: `.env` and `settings.yaml` in the `./christmas` directory.
|
||||
|
||||
- `.env` contains the environment variables required to run the GraphRAG pipeline. If you inspect the file, you'll see a single environment variable defined,
|
||||
`GRAPHRAG_API_KEY=<API_KEY>`. Replace `<API_KEY>` with your own OpenAI or Azure API key.
|
||||
@ -65,11 +67,10 @@ deployment_name: <azure_model_deployment_name>
|
||||
```
|
||||
|
||||
#### Using Managed Auth on Azure
|
||||
To use managed auth, add an additional value to your model config and comment out or remove the api_key line:
|
||||
To use managed auth, edit the auth_type in your model config and *remove* the api_key line:
|
||||
|
||||
```yaml
|
||||
auth_type: azure_managed_identity # Default auth_type is is api_key
|
||||
# api_key: ${GRAPHRAG_API_KEY}
|
||||
```
|
||||
|
||||
You will also need to login with [az login](https://learn.microsoft.com/en-us/cli/azure/authenticate-azure-cli) and select the subscription with your endpoint.
|
||||
@ -79,13 +80,13 @@ You will also need to login with [az login](https://learn.microsoft.com/en-us/cl
|
||||
Finally we'll run the pipeline!
|
||||
|
||||
```sh
|
||||
graphrag index --root ./ragtest
|
||||
graphrag index --root ./christmas
|
||||
```
|
||||
|
||||

|
||||
|
||||
This process will take some time to run. This depends on the size of your input data, what model you're using, and the text chunk size being used (these can be configured in your `settings.yaml` file).
|
||||
Once the pipeline is complete, you should see a new folder called `./ragtest/output` with a series of parquet files.
|
||||
Once the pipeline is complete, you should see a new folder called `./christmas/output` with a series of parquet files.
|
||||
|
||||
# Using the Query Engine
|
||||
|
||||
@ -95,7 +96,7 @@ Here is an example using Global search to ask a high-level question:
|
||||
|
||||
```sh
|
||||
graphrag query \
|
||||
--root ./ragtest \
|
||||
--root ./christmas \
|
||||
--method global \
|
||||
--query "What are the top themes in this story?"
|
||||
```
|
||||
@ -104,7 +105,7 @@ Here is an example using Local search to ask a more specific question about a pa
|
||||
|
||||
```sh
|
||||
graphrag query \
|
||||
--root ./ragtest \
|
||||
--root ./christmas \
|
||||
--method local \
|
||||
--query "Who is Scrooge and what are his main relationships?"
|
||||
```
|
||||
@ -116,4 +117,4 @@ Please refer to [Query Engine](query/overview.md) docs for detailed information
|
||||
- For more details about configuring GraphRAG, see the [configuration documentation](config/overview.md).
|
||||
- To learn more about Initialization, refer to the [Initialization documentation](config/init.md).
|
||||
- For more details about using the CLI, refer to the [CLI documentation](cli.md).
|
||||
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.
|
||||
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.
|
||||
|
||||
@ -47,6 +47,7 @@ At query time, these structures are used to provide materials for the LLM contex
|
||||
- [_Global Search_](query/global_search.md) for reasoning about holistic questions about the corpus by leveraging the community summaries.
|
||||
- [_Local Search_](query/local_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts.
|
||||
- [_DRIFT Search_](query/drift_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts, but with the added context of community information.
|
||||
- _Basic Search_ for those times when your query is best answered by baseline RAG (standard top _k_ vector search).
|
||||
|
||||
### Prompt Tuning
|
||||
|
||||
|
||||
@ -32,3 +32,20 @@ The GraphRAG library was designed with LLM interactions in mind, and a common se
|
||||
Because of these potential error cases, we've added a cache layer around LLM interactions.
|
||||
When completion requests are made using the same input set (prompt and tuning parameters), we return a cached result if one exists.
|
||||
This allows our indexer to be more resilient to network issues, to act idempotently, and to provide a more efficient end-user experience.
|
||||
|
||||
### Providers & Factories
|
||||
|
||||
Several subsystems within GraphRAG use a factory pattern to register and retrieve provider implementations. This allows deep customization to support models, storage, and so on that you may use but isn't built directly into GraphRAG.
|
||||
|
||||
The following subsystems use a factory pattern that allows you to register your own implementations:
|
||||
|
||||
- [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support
|
||||
- [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide
|
||||
- [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage
|
||||
- [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in
|
||||
- [vector store](https://github.com/microsoft/graphrag/blob/main/graphrag/vector_stores/factory.py) - implement your own vector store other than the built-in lancedb, Azure AI Search, and CosmosDB ones built in
|
||||
- [pipeline + workflows](https://github.com/microsoft/graphrag/blob/main/graphrag/index/workflows/factory.py) - implement your own workflow steps with a custom `run_workflow` function, or register an entire pipeline (list of named workflows)
|
||||
|
||||
The links for each of these subsystems point to the source code of the factory, which includes registration of the default built-in implementations. In addition, we have a detailed discussion of [language models](../config/models.md), which includes and example of a custom provider, and a [sample notebook](../examples_notebooks/custom_vector_store.ipynb) that demonstrates a custom vector store.
|
||||
|
||||
All of these factories allow you to register an impl using any string name you would like, even overriding built-in ones directly.
|
||||
@ -16,6 +16,10 @@ All input formats are loaded within GraphRAG and passed to the indexing pipeline
|
||||
|
||||
Also see the [outputs](outputs.md) documentation for the final documents table schema saved to parquet after pipeline completion.
|
||||
|
||||
## Bring-your-own DataFrame
|
||||
|
||||
As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. __You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same.
|
||||
|
||||
## Formats
|
||||
|
||||
We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries.
|
||||
|
||||
@ -26,8 +26,7 @@ After you have a config file you can run the pipeline using the CLI or the Pytho
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Via Poetry
|
||||
poetry run poe index --root <data_root> # default config mode
|
||||
uv run poe index --root <data_root> # default config mode
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
@ -58,13 +58,13 @@ graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selec
|
||||
```bash
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --domain "environmental news" \
|
||||
--selection-method random --limit 10 --language English --max-tokens 2048 --chunk-size 256 --min-examples-required 3 \
|
||||
--no-entity-types --output /path/to/output
|
||||
--no-discover-entity-types --output /path/to/output
|
||||
```
|
||||
|
||||
or, with minimal configuration (suggested):
|
||||
|
||||
```bash
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-entity-types
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-discover-entity-types
|
||||
```
|
||||
|
||||
## Document Selection Methods
|
||||
@ -79,15 +79,7 @@ After that, it uses one of the following selection methods to pick a sample to w
|
||||
|
||||
## Modify Env Vars
|
||||
|
||||
After running auto tuning, you should modify the following environment variables (or config variables) to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
|
||||
|
||||
- `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` = "prompts/entity_extraction.txt"
|
||||
|
||||
- `GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE` = "prompts/community_report.txt"
|
||||
|
||||
- `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` = "prompts/summarize_descriptions.txt"
|
||||
|
||||
or in your yaml config file:
|
||||
After running auto tuning, you should modify the following config variables to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
|
||||
|
||||
```yaml
|
||||
entity_extraction:
|
||||
|
||||
@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
|
||||
- `llm`: OpenAI model object to be used for response generation
|
||||
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
|
||||
- `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
|
||||
- `token_encoder`: token encoder for tracking the budget for the algorithm.
|
||||
- `tokenizer`: token encoder for tracking the budget for the algorithm.
|
||||
- `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
|
||||
|
||||
## How to Use
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# API Notebooks
|
||||
|
||||
- [API Overview Notebook](../../examples_notebooks/api_overview.ipynb)
|
||||
- [Bring-Your-Own Vector Store](../../examples_notebooks/custom_vector_store.ipynb)
|
||||
|
||||
# Query Engine Notebooks
|
||||
|
||||
|
||||
@ -29,6 +29,9 @@
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"\n",
|
||||
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
@ -38,9 +41,6 @@
|
||||
" read_indexer_reports,\n",
|
||||
" read_indexer_text_units,\n",
|
||||
")\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"from graphrag.query.structured_search.local_search.mixed_context import (\n",
|
||||
" LocalSearchMixedContext,\n",
|
||||
")\n",
|
||||
|
||||
@ -2,10 +2,3 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""The GraphRAG package."""
|
||||
|
||||
import logging
|
||||
|
||||
from graphrag.logger.standard_logging import init_console_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
init_console_logger()
|
||||
|
||||
@ -9,6 +9,9 @@ Backwards compatibility is not guaranteed at this time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
@ -17,7 +20,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.run.run_pipeline import run_pipeline
|
||||
from graphrag.index.run.utils import create_callback_chain
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.index.typing.workflow import WorkflowFunction
|
||||
from graphrag.index.workflows.factory import PipelineFactory
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
@ -30,6 +32,9 @@ async def build_index(
|
||||
is_update_run: bool = False,
|
||||
memory_profile: bool = False,
|
||||
callbacks: list[WorkflowCallbacks] | None = None,
|
||||
additional_context: dict[str, Any] | None = None,
|
||||
verbose: bool = False,
|
||||
input_documents: pd.DataFrame | None = None,
|
||||
) -> list[PipelineRunResult]:
|
||||
"""Run the pipeline with the given configuration.
|
||||
|
||||
@ -43,13 +48,17 @@ async def build_index(
|
||||
Whether to enable memory profiling.
|
||||
callbacks : list[WorkflowCallbacks] | None default=None
|
||||
A list of callbacks to register.
|
||||
additional_context : dict[str, Any] | None default=None
|
||||
Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
|
||||
input_documents : pd.DataFrame | None default=None.
|
||||
Override document loading and parsing and supply your own dataframe of documents to index.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PipelineRunResult]
|
||||
The list of pipeline run results
|
||||
"""
|
||||
init_loggers(config=config)
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
|
||||
# Create callbacks for pipeline lifecycle events if provided
|
||||
workflow_callbacks = (
|
||||
@ -73,6 +82,8 @@ async def build_index(
|
||||
config,
|
||||
callbacks=workflow_callbacks,
|
||||
is_update_run=is_update_run,
|
||||
additional_context=additional_context,
|
||||
input_documents=input_documents,
|
||||
):
|
||||
outputs.append(output)
|
||||
if output.errors and len(output.errors) > 0:
|
||||
@ -85,11 +96,6 @@ async def build_index(
|
||||
return outputs
|
||||
|
||||
|
||||
def register_workflow_function(name: str, workflow: WorkflowFunction):
|
||||
"""Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
|
||||
PipelineFactory.register(name, workflow)
|
||||
|
||||
|
||||
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
|
||||
m = method.value if isinstance(method, IndexingMethod) else method
|
||||
return f"{m}-update" if is_update_run else m
|
||||
|
||||
@ -47,6 +47,7 @@ from graphrag.prompt_tune.generator.language import detect_language
|
||||
from graphrag.prompt_tune.generator.persona import generate_persona
|
||||
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
|
||||
from graphrag.prompt_tune.types import DocSelectionType
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -67,6 +68,7 @@ async def generate_indexing_prompts(
|
||||
min_examples_required: PositiveInt = 2,
|
||||
n_subset_max: PositiveInt = 300,
|
||||
k: PositiveInt = 15,
|
||||
verbose: bool = False,
|
||||
) -> tuple[str, str, str]:
|
||||
"""Generate indexing prompts.
|
||||
|
||||
@ -89,7 +91,7 @@ async def generate_indexing_prompts(
|
||||
-------
|
||||
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
|
||||
"""
|
||||
init_loggers(config=config)
|
||||
init_loggers(config=config, verbose=verbose, filename="prompt-tuning.log")
|
||||
|
||||
# Retrieve documents
|
||||
logger.info("Chunking documents...")
|
||||
@ -165,7 +167,7 @@ async def generate_indexing_prompts(
|
||||
examples=examples,
|
||||
language=language,
|
||||
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
|
||||
encoding_model=extract_graph_llm_settings.encoding_model,
|
||||
tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
|
||||
max_token_count=max_tokens,
|
||||
min_examples_required=min_examples_required,
|
||||
)
|
||||
|
||||
@ -93,7 +93,7 @@ async def global_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
@ -156,7 +156,7 @@ def global_search_streaming(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
communities_ = read_indexer_communities(communities, community_reports)
|
||||
reports = read_indexer_reports(
|
||||
@ -229,7 +229,11 @@ async def multi_index_global_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
@ -369,7 +373,7 @@ async def local_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
@ -435,7 +439,7 @@ def local_search_streaming(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
@ -508,8 +512,11 @@ async def multi_index_local_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_index_local_search"
|
||||
@ -730,7 +737,7 @@ async def drift_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
@ -792,7 +799,7 @@ def drift_search_streaming(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
@ -872,7 +879,11 @@ async def multi_index_drift_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
@ -1065,7 +1076,7 @@ async def basic_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
@ -1111,7 +1122,7 @@ def basic_search_streaming(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
@ -1119,7 +1130,7 @@ def basic_search_streaming(
|
||||
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
||||
logger.debug(msg)
|
||||
|
||||
description_embedding_store = get_embedding_store(
|
||||
embedding_store = get_embedding_store(
|
||||
config_args=vector_store_args,
|
||||
embedding_name=text_unit_text_embedding,
|
||||
)
|
||||
@ -1130,7 +1141,7 @@ def basic_search_streaming(
|
||||
search_engine = get_basic_search_engine(
|
||||
config=config,
|
||||
text_units=read_indexer_text_units(text_units),
|
||||
text_unit_embeddings=description_embedding_store,
|
||||
text_unit_embeddings=embedding_store,
|
||||
system_prompt=prompt,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -1164,7 +1175,11 @@ async def multi_index_basic_search(
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
|
||||
128
graphrag/cache/factory.py
vendored
128
graphrag/cache/factory.py
vendored
@ -1,23 +1,24 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_cache method definition."""
|
||||
"""Factory functions for creating a cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.storage.blob_pipeline_storage import create_blob_storage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
|
||||
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
|
||||
from graphrag.cache.memory_pipeline_cache import InMemoryCache
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
|
||||
|
||||
class CacheFactory:
|
||||
@ -25,39 +26,90 @@ class CacheFactory:
|
||||
|
||||
Includes a method for users to register a custom cache implementation.
|
||||
|
||||
Configuration arguments are passed to each cache implementation as kwargs (where possible)
|
||||
Configuration arguments are passed to each cache implementation as kwargs
|
||||
for individual enforcement of required/optional arguments.
|
||||
"""
|
||||
|
||||
cache_types: ClassVar[dict[str, type]] = {}
|
||||
_registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, cache_type: str, cache: type):
|
||||
"""Register a custom cache implementation."""
|
||||
cls.cache_types[cache_type] = cache
|
||||
def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
|
||||
"""Register a custom cache implementation.
|
||||
|
||||
Args:
|
||||
cache_type: The type identifier for the cache.
|
||||
creator: A class or callable that creates an instance of PipelineCache.
|
||||
"""
|
||||
cls._registry[cache_type] = creator
|
||||
|
||||
@classmethod
|
||||
def create_cache(
|
||||
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
|
||||
) -> PipelineCache:
|
||||
"""Create or get a cache from the provided type."""
|
||||
if not cache_type:
|
||||
return NoopPipelineCache()
|
||||
match cache_type:
|
||||
case CacheType.none:
|
||||
return NoopPipelineCache()
|
||||
case CacheType.memory:
|
||||
return InMemoryCache()
|
||||
case CacheType.file:
|
||||
return JsonPipelineCache(
|
||||
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
|
||||
)
|
||||
case CacheType.blob:
|
||||
return JsonPipelineCache(create_blob_storage(**kwargs))
|
||||
case CacheType.cosmosdb:
|
||||
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
|
||||
case _:
|
||||
if cache_type in cls.cache_types:
|
||||
return cls.cache_types[cache_type](**kwargs)
|
||||
msg = f"Unknown cache type: {cache_type}"
|
||||
raise ValueError(msg)
|
||||
def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
|
||||
"""Create a cache object from the provided type.
|
||||
|
||||
Args:
|
||||
cache_type: The type of cache to create.
|
||||
root_dir: The root directory for file-based caches.
|
||||
kwargs: Additional keyword arguments for the cache constructor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A PipelineCache instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the cache type is not registered.
|
||||
"""
|
||||
if cache_type not in cls._registry:
|
||||
msg = f"Unknown cache type: {cache_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls._registry[cache_type](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_cache_types(cls) -> list[str]:
|
||||
"""Get the registered cache implementations."""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported_type(cls, cache_type: str) -> bool:
|
||||
"""Check if the given cache type is supported."""
|
||||
return cache_type in cls._registry
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
|
||||
"""Create a file-based cache implementation."""
|
||||
# Create storage with base_dir in kwargs since FilePipelineStorage expects it there
|
||||
storage_kwargs = {"base_dir": root_dir, **kwargs}
|
||||
storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_blob_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a blob storage-based cache implementation."""
|
||||
storage = BlobPipelineStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a CosmosDB-based cache implementation."""
|
||||
storage = CosmosDBPipelineStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_noop_cache(**_kwargs) -> PipelineCache:
|
||||
"""Create a no-op cache implementation."""
|
||||
return NoopPipelineCache()
|
||||
|
||||
|
||||
def create_memory_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a memory cache implementation."""
|
||||
return InMemoryCache(**kwargs)
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
CacheFactory.register(CacheType.none.value, create_noop_cache)
|
||||
CacheFactory.register(CacheType.memory.value, create_memory_cache)
|
||||
CacheFactory.register(CacheType.file.value, create_file_cache)
|
||||
CacheFactory.register(CacheType.blob.value, create_blob_cache)
|
||||
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
|
||||
|
||||
46
graphrag/callbacks/console_workflow_callbacks.py
Normal file
46
graphrag/callbacks/console_workflow_callbacks.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A logger that emits updates from the indexing engine to the console."""
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
# ruff: noqa: T201
|
||||
|
||||
|
||||
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A logger that writes to a console."""
|
||||
|
||||
_verbose = False
|
||||
|
||||
def __init__(self, verbose=False):
|
||||
self._verbose = verbose
|
||||
|
||||
def pipeline_start(self, names: list[str]) -> None:
|
||||
"""Execute this callback to signal when the entire pipeline starts."""
|
||||
print("Starting pipeline with workflows:", ", ".join(names))
|
||||
|
||||
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
|
||||
"""Execute this callback to signal when the entire pipeline ends."""
|
||||
print("Pipeline complete")
|
||||
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
print(f"Starting workflow: {name}")
|
||||
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
print("") # account for potential return on prior progress
|
||||
print(f"Workflow complete: {name}")
|
||||
if self._verbose:
|
||||
print(instance)
|
||||
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
complete = progress.completed_items or 0
|
||||
total = progress.total_items or 1
|
||||
percent = round((complete / total) * 100)
|
||||
start = f" {complete} / {total} "
|
||||
print(f"{start:{'.'}<{percent}}", flush=True, end="\r")
|
||||
@ -10,7 +10,8 @@ import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.api as api
|
||||
from graphrag.config.enums import CacheType, IndexingMethod, ReportingType
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.config.enums import CacheType, IndexingMethod
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.index.validate_config import validate_config_names
|
||||
from graphrag.utils.cli import redact
|
||||
@ -115,24 +116,12 @@ def _run_index(
|
||||
# Initialize loggers and reporting config
|
||||
init_loggers(
|
||||
config=config,
|
||||
root_dir=str(config.root_dir) if config.root_dir else None,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if not cache:
|
||||
config.cache.type = CacheType.none
|
||||
|
||||
# Log the configuration details
|
||||
if config.reporting.type == ReportingType.file:
|
||||
log_dir = Path(config.root_dir or "") / (config.reporting.base_dir or "")
|
||||
log_path = log_dir / "logs.txt"
|
||||
logger.info("Logging enabled at %s", log_path)
|
||||
else:
|
||||
logger.info(
|
||||
"Logging not enabled for config %s",
|
||||
redact(config.model_dump()),
|
||||
)
|
||||
|
||||
if not skip_validation:
|
||||
validate_config_names(config)
|
||||
|
||||
@ -154,6 +143,8 @@ def _run_index(
|
||||
method=method,
|
||||
is_update_run=is_update_run,
|
||||
memory_profile=memprofile,
|
||||
callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
encountered_errors = any(
|
||||
|
||||
@ -7,7 +7,6 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.api as api
|
||||
from graphrag.config.enums import ReportingType
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.prompt_tune.generator.community_report_summarization import (
|
||||
COMMUNITY_SUMMARIZATION_FILENAME,
|
||||
@ -73,22 +72,13 @@ async def prompt_tune(
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
# initialize loggers with config
|
||||
init_loggers(
|
||||
config=graph_config,
|
||||
root_dir=str(root_path),
|
||||
verbose=verbose,
|
||||
)
|
||||
init_loggers(config=graph_config, verbose=verbose, filename="prompt-tuning.log")
|
||||
|
||||
# log the configuration details
|
||||
if graph_config.reporting.type == ReportingType.file:
|
||||
log_dir = Path(root_path) / (graph_config.reporting.base_dir or "")
|
||||
log_path = log_dir / "logs.txt"
|
||||
logger.info("Logging enabled at %s", log_path)
|
||||
else:
|
||||
logger.info(
|
||||
"Logging not enabled for config %s",
|
||||
redact(graph_config.model_dump()),
|
||||
)
|
||||
logger.info("Starting prompt tune.")
|
||||
logger.info(
|
||||
"Using default configuration: %s",
|
||||
redact(graph_config.model_dump()),
|
||||
)
|
||||
|
||||
prompts = await api.generate_indexing_prompts(
|
||||
config=graph_config,
|
||||
@ -103,6 +93,7 @@ async def prompt_tune(
|
||||
min_examples_required=min_examples_required,
|
||||
n_subset_max=n_subset_max,
|
||||
k=k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
output_path = output.resolve()
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
"""CLI implementation of the query subcommand."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -19,8 +18,7 @@ from graphrag.utils.storage import load_table_from_storage, storage_has_table
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
# Initialize standard logger
|
||||
logger = logging.getLogger(__name__)
|
||||
# ruff: noqa: T201
|
||||
|
||||
|
||||
def run_global_search(
|
||||
@ -61,10 +59,6 @@ def run_global_search(
|
||||
final_community_reports_list = dataframe_dict["community_reports"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.info(
|
||||
"Running multi-index global search on indexes: %s",
|
||||
dataframe_dict["index_names"],
|
||||
)
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_global_search(
|
||||
config=config,
|
||||
@ -80,11 +74,7 @@ def run_global_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Query Response:\n%s", response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Global Search API
|
||||
@ -118,9 +108,9 @@ def run_global_search(
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -138,11 +128,8 @@ def run_global_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Global Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -188,11 +175,6 @@ def run_local_search(
|
||||
final_relationships_list = dataframe_dict["relationships"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.info(
|
||||
"Running multi-index local search on indexes: %s",
|
||||
dataframe_dict["index_names"],
|
||||
)
|
||||
|
||||
# If any covariates tables are missing from any index, set the covariates list to None
|
||||
if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]:
|
||||
final_covariates_list = None
|
||||
@ -216,11 +198,8 @@ def run_local_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Local Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Local Search API
|
||||
@ -259,9 +238,9 @@ def run_local_search(
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -281,11 +260,8 @@ def run_local_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Local Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -329,11 +305,6 @@ def run_drift_search(
|
||||
final_relationships_list = dataframe_dict["relationships"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.info(
|
||||
"Running multi-index drift search on indexes: %s",
|
||||
dataframe_dict["index_names"],
|
||||
)
|
||||
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_drift_search(
|
||||
config=config,
|
||||
@ -350,11 +321,8 @@ def run_drift_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("DRIFT Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Drift Search API
|
||||
@ -391,9 +359,9 @@ def run_drift_search(
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -413,11 +381,8 @@ def run_drift_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("DRIFT Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -451,11 +416,6 @@ def run_basic_search(
|
||||
final_text_units_list = dataframe_dict["text_units"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.info(
|
||||
"Running multi-index basic search on indexes: %s",
|
||||
dataframe_dict["index_names"],
|
||||
)
|
||||
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_basic_search(
|
||||
config=config,
|
||||
@ -466,11 +426,8 @@ def run_basic_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Basic Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Basic Search API
|
||||
@ -497,9 +454,9 @@ def run_basic_search(
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -512,11 +469,8 @@ def run_basic_search(
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
# log the full response at INFO level for user visibility but at DEBUG level in the API layer
|
||||
logger.info("Basic Search Response:\n%s", response)
|
||||
print(response)
|
||||
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
return response, context_data
|
||||
|
||||
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
|
||||
"""Common default configuration values."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag.config.embeddings import default_embeddings
|
||||
from graphrag.config.enums import (
|
||||
@ -18,27 +19,59 @@ from graphrag.config.enums import (
|
||||
NounPhraseExtractorType,
|
||||
ReportingType,
|
||||
StorageType,
|
||||
VectorStoreType,
|
||||
)
|
||||
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
|
||||
EN_STOP_WORDS,
|
||||
)
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
|
||||
RateLimiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
|
||||
StaticRateLimiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
|
||||
ExponentialRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
|
||||
IncrementalWaitRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
|
||||
NativeRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
|
||||
RandomWaitRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
|
||||
|
||||
DEFAULT_OUTPUT_BASE_DIR = "output"
|
||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
|
||||
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
|
||||
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_MODEL_PROVIDER = "openai"
|
||||
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
|
||||
|
||||
ENCODING_MODEL = "cl100k_base"
|
||||
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
|
||||
DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
|
||||
"native": NativeRetry,
|
||||
"exponential_backoff": ExponentialRetry,
|
||||
"random_wait": RandomWaitRetry,
|
||||
"incremental_wait": IncrementalWaitRetry,
|
||||
}
|
||||
|
||||
DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
|
||||
"static": StaticRateLimiter,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicSearchDefaults:
|
||||
"""Default values for basic search."""
|
||||
@ -214,6 +247,7 @@ class ExtractGraphNLPDefaults:
|
||||
normalize_edge_weights: bool = True
|
||||
text_analyzer: TextAnalyzerDefaults = field(default_factory=TextAnalyzerDefaults)
|
||||
concurrent_requests: int = 25
|
||||
async_mode: AsyncType = AsyncType.Threaded
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -239,7 +273,7 @@ class GlobalSearchDefaults:
|
||||
class StorageDefaults:
|
||||
"""Default values for storage."""
|
||||
|
||||
type = StorageType.file
|
||||
type: ClassVar[StorageType] = StorageType.file
|
||||
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
@ -274,6 +308,7 @@ class LanguageModelDefaults:
|
||||
|
||||
api_key: None = None
|
||||
auth_type: ClassVar[AuthType] = AuthType.APIKey
|
||||
model_provider: str | None = None
|
||||
encoding_model: str = ""
|
||||
max_tokens: int | None = None
|
||||
temperature: float = 0
|
||||
@ -291,9 +326,10 @@ class LanguageModelDefaults:
|
||||
proxy: None = None
|
||||
audience: None = None
|
||||
model_supports_json: None = None
|
||||
tokens_per_minute: Literal["auto"] = "auto"
|
||||
requests_per_minute: Literal["auto"] = "auto"
|
||||
retry_strategy: str = "native"
|
||||
tokens_per_minute: None = None
|
||||
requests_per_minute: None = None
|
||||
rate_limit_strategy: str | None = "static"
|
||||
retry_strategy: str = "exponential_backoff"
|
||||
max_retries: int = 10
|
||||
max_retry_wait: float = 10.0
|
||||
concurrent_requests: int = 25
|
||||
@ -393,6 +429,7 @@ class VectorStoreDefaults:
|
||||
api_key: None = None
|
||||
audience: None = None
|
||||
database_name: None = None
|
||||
schema: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -29,14 +29,14 @@ default_embeddings: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
def create_collection_name(
|
||||
def create_index_name(
|
||||
container_name: str, embedding_name: str, validate: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a collection name for the embedding store.
|
||||
Create a index name for the embedding store.
|
||||
|
||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
|
||||
|
||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||
|
||||
|
||||
@ -59,6 +59,14 @@ class StorageType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class VectorStoreType(str, Enum):
|
||||
"""The supported vector store types."""
|
||||
|
||||
LanceDB = "lancedb"
|
||||
AzureAISearch = "azure_ai_search"
|
||||
CosmosDB = "cosmosdb"
|
||||
|
||||
|
||||
class ReportingType(str, Enum):
|
||||
"""The reporting configuration type for the pipeline."""
|
||||
|
||||
@ -78,10 +86,12 @@ class ModelType(str, Enum):
|
||||
# Embeddings
|
||||
OpenAIEmbedding = "openai_embedding"
|
||||
AzureOpenAIEmbedding = "azure_openai_embedding"
|
||||
Embedding = "embedding"
|
||||
|
||||
# Chat Completion
|
||||
OpenAIChat = "openai_chat"
|
||||
AzureOpenAIChat = "azure_openai_chat"
|
||||
Chat = "chat"
|
||||
|
||||
# Debug
|
||||
MockChat = "mock_chat"
|
||||
|
||||
@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureDeploymentNameMissingError(ValueError):
|
||||
"""Azure Deployment Name missing error."""
|
||||
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class LanguageModelConfigMissingError(ValueError):
|
||||
"""Missing model configuration error."""
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ def get_embedding_settings(
|
||||
vector_store_params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Transform GraphRAG config into settings for workflows."""
|
||||
# TEMP
|
||||
embeddings_llm_settings = settings.get_language_model_config(
|
||||
settings.embed_text.model_id
|
||||
)
|
||||
|
||||
@ -19,41 +19,34 @@ INIT_YAML = f"""\
|
||||
|
||||
models:
|
||||
{defs.DEFAULT_CHAT_MODEL_ID}:
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
|
||||
model: {defs.DEFAULT_CHAT_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
|
||||
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
||||
requests_per_minute: null # set to null to disable rate limiting or auto for dynamic
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
|
||||
### Input settings ###
|
||||
|
||||
@ -62,7 +55,6 @@ input:
|
||||
type: {graphrag_config_defaults.input.storage.type.value} # or blob
|
||||
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
||||
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
||||
|
||||
|
||||
chunks:
|
||||
size: {graphrag_config_defaults.chunks.size}
|
||||
@ -82,7 +74,7 @@ cache:
|
||||
base_dir: "{graphrag_config_defaults.cache.base_dir}"
|
||||
|
||||
reporting:
|
||||
type: {graphrag_config_defaults.reporting.type.value} # [file, blob, cosmosdb]
|
||||
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
|
||||
base_dir: "{graphrag_config_defaults.reporting.base_dir}"
|
||||
|
||||
vector_store:
|
||||
@ -90,7 +82,6 @@ vector_store:
|
||||
type: {vector_store_defaults.type}
|
||||
db_uri: {vector_store_defaults.db_uri}
|
||||
container_name: {vector_store_defaults.container_name}
|
||||
overwrite: {vector_store_defaults.overwrite}
|
||||
|
||||
### Workflow settings ###
|
||||
|
||||
@ -112,6 +103,7 @@ summarize_descriptions:
|
||||
extract_graph_nlp:
|
||||
text_analyzer:
|
||||
extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
|
||||
async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio
|
||||
|
||||
cluster_graph:
|
||||
max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}
|
||||
|
||||
@ -12,7 +12,7 @@ from graphrag.config.enums import CacheType
|
||||
class CacheConfig(BaseModel):
|
||||
"""The default configuration section for Cache."""
|
||||
|
||||
type: CacheType = Field(
|
||||
type: CacheType | str = Field(
|
||||
description="The cache type to use.",
|
||||
default=graphrag_config_defaults.cache.type,
|
||||
)
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import NounPhraseExtractorType
|
||||
from graphrag.config.enums import AsyncType, NounPhraseExtractorType
|
||||
|
||||
|
||||
class TextAnalyzerConfig(BaseModel):
|
||||
@ -68,3 +68,7 @@ class ExtractGraphNLPConfig(BaseModel):
|
||||
description="The number of threads to use for the extraction process.",
|
||||
default=graphrag_config_defaults.extract_graph_nlp.concurrent_requests,
|
||||
)
|
||||
async_mode: AsyncType = Field(
|
||||
description="The async mode to use.",
|
||||
default=graphrag_config_defaults.extract_graph_nlp.async_mode,
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.errors import LanguageModelConfigMissingError
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
@ -36,7 +37,12 @@ from graphrag.config.models.summarize_descriptions_config import (
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
|
||||
RateLimiterFactory,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
|
||||
RetryFactory,
|
||||
)
|
||||
|
||||
|
||||
class GraphRagConfig(BaseModel):
|
||||
@ -89,6 +95,47 @@ class GraphRagConfig(BaseModel):
|
||||
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
|
||||
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
|
||||
|
||||
def _validate_retry_services(self) -> None:
|
||||
"""Validate the retry services configuration."""
|
||||
retry_factory = RetryFactory()
|
||||
|
||||
for model_id, model in self.models.items():
|
||||
if model.retry_strategy != "none":
|
||||
if model.retry_strategy not in retry_factory:
|
||||
msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
|
||||
raise ValueError(msg)
|
||||
|
||||
_ = retry_factory.create(
|
||||
strategy=model.retry_strategy,
|
||||
max_retries=model.max_retries,
|
||||
max_retry_wait=model.max_retry_wait,
|
||||
)
|
||||
|
||||
def _validate_rate_limiter_services(self) -> None:
|
||||
"""Validate the rate limiter services configuration."""
|
||||
rate_limiter_factory = RateLimiterFactory()
|
||||
|
||||
for model_id, model in self.models.items():
|
||||
if model.rate_limit_strategy is not None:
|
||||
if model.rate_limit_strategy not in rate_limiter_factory:
|
||||
msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
|
||||
raise ValueError(msg)
|
||||
|
||||
rpm = (
|
||||
model.requests_per_minute
|
||||
if type(model.requests_per_minute) is int
|
||||
else None
|
||||
)
|
||||
tpm = (
|
||||
model.tokens_per_minute
|
||||
if type(model.tokens_per_minute) is int
|
||||
else None
|
||||
)
|
||||
if rpm is not None or tpm is not None:
|
||||
_ = rate_limiter_factory.create(
|
||||
strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
|
||||
)
|
||||
|
||||
input: InputConfig = Field(
|
||||
description="The input configuration.", default=InputConfig()
|
||||
)
|
||||
@ -300,6 +347,11 @@ class GraphRagConfig(BaseModel):
|
||||
raise ValueError(msg)
|
||||
store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
|
||||
|
||||
def _validate_factories(self) -> None:
|
||||
"""Validate the factories used in the configuration."""
|
||||
self._validate_retry_services()
|
||||
self._validate_rate_limiter_services()
|
||||
|
||||
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
|
||||
"""Get a model configuration by ID.
|
||||
|
||||
@ -360,4 +412,5 @@ class GraphRagConfig(BaseModel):
|
||||
self._validate_multi_output_base_dirs()
|
||||
self._validate_update_index_output_base_dir()
|
||||
self._validate_vector_store_db_uri()
|
||||
self._validate_factories()
|
||||
return self
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""Language model configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import tiktoken
|
||||
@ -14,11 +15,12 @@ from graphrag.config.errors import (
|
||||
ApiKeyMissingError,
|
||||
AzureApiBaseMissingError,
|
||||
AzureApiVersionMissingError,
|
||||
AzureDeploymentNameMissingError,
|
||||
ConflictingSettingsError,
|
||||
)
|
||||
from graphrag.language_model.factory import ModelFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LanguageModelConfig(BaseModel):
|
||||
"""Language model configuration."""
|
||||
@ -73,8 +75,11 @@ class LanguageModelConfig(BaseModel):
|
||||
ConflictingSettingsError
|
||||
If the Azure authentication type conflicts with the model being used.
|
||||
"""
|
||||
if self.auth_type == AuthType.AzureManagedIdentity and (
|
||||
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
|
||||
if (
|
||||
self.auth_type == AuthType.AzureManagedIdentity
|
||||
and self.type != ModelType.AzureOpenAIChat
|
||||
and self.type != ModelType.AzureOpenAIEmbedding
|
||||
and self.model_provider != "azure" # indicates Litellm + AOI
|
||||
):
|
||||
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
|
||||
raise ConflictingSettingsError(msg)
|
||||
@ -93,6 +98,35 @@ class LanguageModelConfig(BaseModel):
|
||||
if not ModelFactory.is_supported_model(self.type):
|
||||
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
|
||||
raise KeyError(msg)
|
||||
if self.type in [
|
||||
"openai_chat",
|
||||
"openai_embedding",
|
||||
"azure_openai_chat",
|
||||
"azure_openai_embedding",
|
||||
]:
|
||||
msg = f"Model config based on fnllm is deprecated and will be removed in GraphRAG v3, please use {ModelType.Chat} or {ModelType.Embedding} instead to switch to LiteLLM config."
|
||||
logger.warning(msg)
|
||||
|
||||
model_provider: str | None = Field(
|
||||
description="The model provider to use.",
|
||||
default=language_model_defaults.model_provider,
|
||||
)
|
||||
|
||||
def _validate_model_provider(self) -> None:
|
||||
"""Validate the model provider.
|
||||
|
||||
Required when using Litellm.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the model provider is not recognized.
|
||||
"""
|
||||
if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
|
||||
self.model_provider is None or self.model_provider.strip() == ""
|
||||
):
|
||||
msg = f"Model provider must be specified when using type == {self.type}."
|
||||
raise KeyError(msg)
|
||||
|
||||
model: str = Field(description="The LLM model to use.")
|
||||
encoding_model: str = Field(
|
||||
@ -103,12 +137,27 @@ class LanguageModelConfig(BaseModel):
|
||||
def _validate_encoding_model(self) -> None:
|
||||
"""Validate the encoding model.
|
||||
|
||||
The default behavior is to use an encoding model that matches the LLM model.
|
||||
LiteLLM supports 100+ models and their tokenization. There is no need to
|
||||
set the encoding model when using the new LiteLLM provider as was done with fnllm provider.
|
||||
|
||||
Users can still manually specify a tiktoken based encoding model to use even with the LiteLLM provider
|
||||
in which case the specified encoding model will be used regardless of the LLM model being used, even if
|
||||
it is not an openai based model.
|
||||
|
||||
If not using LiteLLM provider, set the encoding model based on the LLM model name.
|
||||
This is for backward compatibility with existing fnllm provider until fnllm is removed.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the model name is not recognized.
|
||||
"""
|
||||
if self.encoding_model.strip() == "":
|
||||
if (
|
||||
self.type != ModelType.Chat
|
||||
and self.type != ModelType.Embedding
|
||||
and self.encoding_model.strip() == ""
|
||||
):
|
||||
self.encoding_model = tiktoken.encoding_name_for_model(self.model)
|
||||
|
||||
api_base: str | None = Field(
|
||||
@ -129,6 +178,7 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.api_base is None or self.api_base.strip() == ""):
|
||||
raise AzureApiBaseMissingError(self.type)
|
||||
|
||||
@ -150,6 +200,7 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.api_version is None or self.api_version.strip() == ""):
|
||||
raise AzureApiVersionMissingError(self.type)
|
||||
|
||||
@ -171,8 +222,10 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
|
||||
raise AzureDeploymentNameMissingError(self.type)
|
||||
msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
|
||||
logger.debug(msg)
|
||||
|
||||
organization: str | None = Field(
|
||||
description="The organization to use for the LLM service.",
|
||||
@ -212,6 +265,14 @@ class LanguageModelConfig(BaseModel):
|
||||
msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
|
||||
and self.rate_limit_strategy is not None
|
||||
and self.tokens_per_minute == "auto"
|
||||
):
|
||||
msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
|
||||
raise ValueError(msg)
|
||||
|
||||
requests_per_minute: int | Literal["auto"] | None = Field(
|
||||
description="The number of requests per minute to use for the LLM service.",
|
||||
default=language_model_defaults.requests_per_minute,
|
||||
@ -230,6 +291,19 @@ class LanguageModelConfig(BaseModel):
|
||||
msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
|
||||
and self.rate_limit_strategy is not None
|
||||
and self.requests_per_minute == "auto"
|
||||
):
|
||||
msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
|
||||
raise ValueError(msg)
|
||||
|
||||
rate_limit_strategy: str | None = Field(
|
||||
description="The rate limit strategy to use for the LLM service.",
|
||||
default=language_model_defaults.rate_limit_strategy,
|
||||
)
|
||||
|
||||
retry_strategy: str = Field(
|
||||
description="The retry strategy to use for the LLM service.",
|
||||
default=language_model_defaults.retry_strategy,
|
||||
@ -318,6 +392,7 @@ class LanguageModelConfig(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
self._validate_type()
|
||||
self._validate_model_provider()
|
||||
self._validate_auth_type()
|
||||
self._validate_api_key()
|
||||
self._validate_tokens_per_minute()
|
||||
|
||||
@ -12,7 +12,7 @@ from graphrag.config.enums import ReportingType
|
||||
class ReportingConfig(BaseModel):
|
||||
"""The default configuration section for Reporting."""
|
||||
|
||||
type: ReportingType = Field(
|
||||
type: ReportingType | str = Field(
|
||||
description="The reporting type to use.",
|
||||
default=graphrag_config_defaults.reporting.type,
|
||||
)
|
||||
|
||||
@ -14,7 +14,7 @@ from graphrag.config.enums import StorageType
|
||||
class StorageConfig(BaseModel):
|
||||
"""The default configuration section for storage."""
|
||||
|
||||
type: StorageType = Field(
|
||||
type: StorageType | str = Field(
|
||||
description="The storage type to use.",
|
||||
default=graphrag_config_defaults.storage.type,
|
||||
)
|
||||
|
||||
@ -39,7 +39,7 @@ class SummarizeDescriptionsConfig(BaseModel):
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved description summarization strategy."""
|
||||
from graphrag.index.operations.summarize_descriptions import (
|
||||
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
|
||||
SummarizeStrategyType,
|
||||
)
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class TextEmbeddingConfig(BaseModel):
|
||||
|
||||
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
|
||||
"""Get the resolved text embedding strategy."""
|
||||
from graphrag.index.operations.embed_text import (
|
||||
from graphrag.index.operations.embed_text.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graphrag.config.defaults import vector_store_defaults
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.config.embeddings import all_embeddings
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
|
||||
default=vector_store_defaults.overwrite,
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
|
||||
|
||||
def _validate_embeddings_schema(self) -> None:
|
||||
"""Validate the embeddings schema."""
|
||||
for name in self.embeddings_schema:
|
||||
if name not in all_embeddings:
|
||||
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.type == VectorStoreType.CosmosDB:
|
||||
for id_field in self.embeddings_schema:
|
||||
if id_field != "id":
|
||||
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_db_uri()
|
||||
self._validate_url()
|
||||
self._validate_embeddings_schema()
|
||||
return self
|
||||
|
||||
66
graphrag/config/models/vector_store_schema_config.py
Normal file
66
graphrag/config/models/vector_store_schema_config.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def is_valid_field_name(field: str) -> bool:
|
||||
"""Check if a field name is valid for CosmosDB."""
|
||||
return bool(VALID_IDENTIFIER_REGEX.match(field))
|
||||
|
||||
|
||||
class VectorStoreSchemaConfig(BaseModel):
|
||||
"""The default configuration section for Vector Store Schema."""
|
||||
|
||||
id_field: str = Field(
|
||||
description="The ID field to use.",
|
||||
default="id",
|
||||
)
|
||||
|
||||
vector_field: str = Field(
|
||||
description="The vector field to use.",
|
||||
default="vector",
|
||||
)
|
||||
|
||||
text_field: str = Field(
|
||||
description="The text field to use.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
attributes_field: str = Field(
|
||||
description="The attributes field to use.",
|
||||
default="attributes",
|
||||
)
|
||||
|
||||
vector_size: int = Field(
|
||||
description="The vector size to use.",
|
||||
default=DEFAULT_VECTOR_SIZE,
|
||||
)
|
||||
|
||||
index_name: str | None = Field(description="The index name to use.", default=None)
|
||||
|
||||
def _validate_schema(self) -> None:
|
||||
"""Validate the schema."""
|
||||
for field in [
|
||||
self.id_field,
|
||||
self.vector_field,
|
||||
self.text_field,
|
||||
self.attributes_field,
|
||||
]:
|
||||
if not is_valid_field_name(field):
|
||||
msg = f"Unsafe or invalid field name: {field}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_schema()
|
||||
return self
|
||||
4
graphrag/factory/__init__.py
Normal file
4
graphrag/factory/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory module."""
|
||||
68
graphrag/factory/factory.py
Normal file
68
graphrag/factory/factory.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory ABC."""
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class Factory(ABC, Generic[T]):
|
||||
"""Abstract base class for factories."""
|
||||
|
||||
_instance: ClassVar["Factory | None"] = None
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Factory":
|
||||
"""Create a new instance of Factory if it does not exist."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._services: dict[str, Callable[..., T]] = {}
|
||||
self._initialized = True
|
||||
|
||||
def __contains__(self, strategy: str) -> bool:
|
||||
"""Check if a strategy is registered."""
|
||||
return strategy in self._services
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Get a list of registered strategy names."""
|
||||
return list(self._services.keys())
|
||||
|
||||
def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None:
|
||||
"""
|
||||
Register a new service.
|
||||
|
||||
Args
|
||||
----
|
||||
strategy: The name of the strategy.
|
||||
service_initializer: A callable that creates an instance of T.
|
||||
"""
|
||||
self._services[strategy] = service_initializer
|
||||
|
||||
def create(self, *, strategy: str, **kwargs: Any) -> T:
|
||||
"""
|
||||
Create a service instance based on the strategy.
|
||||
|
||||
Args
|
||||
----
|
||||
strategy: The name of the strategy.
|
||||
**kwargs: Additional arguments to pass to the service initializer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An instance of T.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the strategy is not registered.
|
||||
"""
|
||||
if strategy not in self._services:
|
||||
msg = f"Strategy '{strategy}' is not registered."
|
||||
raise ValueError(msg)
|
||||
return self._services[strategy](**kwargs)
|
||||
@ -24,12 +24,17 @@ async def build_noun_graph(
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
normalize_edge_weights: bool,
|
||||
num_threads: int = 4,
|
||||
async_mode: AsyncType = AsyncType.Threaded,
|
||||
cache: PipelineCache | None = None,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Build a noun graph from text units."""
|
||||
text_units = text_unit_df.loc[:, ["id", "text"]]
|
||||
nodes_df = await _extract_nodes(
|
||||
text_units, text_analyzer, num_threads=num_threads, cache=cache
|
||||
text_units,
|
||||
text_analyzer,
|
||||
num_threads=num_threads,
|
||||
async_mode=async_mode,
|
||||
cache=cache,
|
||||
)
|
||||
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
|
||||
return (nodes_df, edges_df)
|
||||
@ -39,6 +44,7 @@ async def _extract_nodes(
|
||||
text_unit_df: pd.DataFrame,
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
num_threads: int = 4,
|
||||
async_mode: AsyncType = AsyncType.Threaded,
|
||||
cache: PipelineCache | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -64,7 +70,7 @@ async def _extract_nodes(
|
||||
text_unit_df,
|
||||
extract,
|
||||
num_threads=num_threads,
|
||||
async_type=AsyncType.Threaded,
|
||||
async_type=async_mode,
|
||||
progress_msg="extract noun phrases progress: ",
|
||||
)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ import tiktoken
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.index.operations.chunk_text.typing import TextChunk
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
Tokenizer,
|
||||
TokenChunkerOptions,
|
||||
split_multiple_texts_on_tokens,
|
||||
)
|
||||
from graphrag.logger.progress import ProgressTicker
|
||||
@ -45,7 +45,7 @@ def run_tokens(
|
||||
encode, decode = get_encoding_fn(encoding_name)
|
||||
return split_multiple_texts_on_tokens(
|
||||
input,
|
||||
Tokenizer(
|
||||
TokenChunkerOptions(
|
||||
chunk_overlap=chunk_overlap,
|
||||
tokens_per_chunk=tokens_per_chunk,
|
||||
encode=encode,
|
||||
|
||||
@ -2,10 +2,3 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""The Indexing Engine text embed package root."""
|
||||
|
||||
from graphrag.index.operations.embed_text.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
embed_text,
|
||||
)
|
||||
|
||||
__all__ = ["TextEmbedStrategyType", "embed_text"]
|
||||
|
||||
@ -12,7 +12,8 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
@ -49,9 +50,9 @@ async def embed_text(
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
||||
index_name = _get_index_name(vector_store_config, embedding_name)
|
||||
vector_store: BaseVectorStore = _create_vector_store(
|
||||
vector_store_config, collection_name
|
||||
vector_store_config, index_name, embedding_name
|
||||
)
|
||||
vector_store_workflow_config = vector_store_config.get(
|
||||
embedding_name, vector_store_config
|
||||
@ -183,27 +184,46 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
|
||||
def _create_vector_store(
|
||||
vector_store_config: dict, collection_name: str
|
||||
vector_store_config: dict, index_name: str, embedding_name: str | None = None
|
||||
) -> BaseVectorStore:
|
||||
vector_store_type: str = str(vector_store_config.get("type"))
|
||||
if collection_name:
|
||||
vector_store_config.update({"collection_name": collection_name})
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
|
||||
"embeddings_schema", {}
|
||||
)
|
||||
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||
|
||||
if (
|
||||
embeddings_schema is not None
|
||||
and embedding_name is not None
|
||||
and embedding_name in embeddings_schema
|
||||
):
|
||||
raw_config = embeddings_schema[embedding_name]
|
||||
if isinstance(raw_config, dict):
|
||||
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||
else:
|
||||
single_embedding_config = raw_config
|
||||
|
||||
if single_embedding_config.index_name is None:
|
||||
single_embedding_config.index_name = index_name
|
||||
|
||||
vector_store = VectorStoreFactory().create_vector_store(
|
||||
vector_store_type, kwargs=vector_store_config
|
||||
vector_store_schema_config=single_embedding_config,
|
||||
vector_store_type=vector_store_type,
|
||||
**vector_store_config,
|
||||
)
|
||||
|
||||
vector_store.connect(**vector_store_config)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
container_name = vector_store_config.get("container_name", "default")
|
||||
collection_name = create_collection_name(container_name, embedding_name)
|
||||
index_name = create_index_name(container_name, embedding_name)
|
||||
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
|
||||
logger.info(msg)
|
||||
return collection_name
|
||||
return index_name
|
||||
|
||||
|
||||
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
||||
|
||||
@ -18,6 +18,7 @@ from graphrag.index.utils.is_null import is_null
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.logger.progress import ProgressTicker, progress_ticker
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -79,7 +80,7 @@ def _get_splitter(
|
||||
config: LanguageModelConfig, batch_max_tokens: int
|
||||
) -> TokenTextSplitter:
|
||||
return TokenTextSplitter(
|
||||
encoding_name=config.encoding_model,
|
||||
tokenizer=get_tokenizer(model_config=config),
|
||||
chunk_size=batch_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@ -8,10 +8,12 @@ import graphrag.data_model.schemas as schemas
|
||||
from graphrag.index.operations.summarize_communities.graph_context.sort_context import (
|
||||
sort_context,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
def build_mixed_context(
|
||||
context: list[dict], tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> str:
|
||||
"""
|
||||
Build parent context by concatenating all sub-communities' contexts.
|
||||
|
||||
@ -45,9 +47,10 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
remaining_local_context.extend(sorted_context[rid][schemas.ALL_CONTEXT])
|
||||
new_context_string = sort_context(
|
||||
local_context=remaining_local_context + final_local_contexts,
|
||||
tokenizer=tokenizer,
|
||||
sub_community_reports=substitute_reports,
|
||||
)
|
||||
if num_tokens(new_context_string) <= max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) <= max_context_tokens:
|
||||
exceeded_limit = False
|
||||
context_string = new_context_string
|
||||
break
|
||||
@ -63,7 +66,7 @@ def build_mixed_context(context: list[dict], max_context_tokens: int) -> str:
|
||||
new_context_string = pd.DataFrame(substitute_reports).to_csv(
|
||||
index=False, sep=","
|
||||
)
|
||||
if num_tokens(new_context_string) > max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
|
||||
break
|
||||
|
||||
context_string = new_context_string
|
||||
|
||||
@ -30,7 +30,7 @@ from graphrag.index.utils.dataframes import (
|
||||
where_column_equals,
|
||||
)
|
||||
from graphrag.logger.progress import progress_iterable
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -39,6 +39,7 @@ def build_local_context(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
tokenizer: Tokenizer,
|
||||
callbacks: WorkflowCallbacks,
|
||||
max_context_tokens: int = 16_000,
|
||||
):
|
||||
@ -49,7 +50,7 @@ def build_local_context(
|
||||
|
||||
for level in progress_iterable(levels, callbacks.progress, len(levels)):
|
||||
communities_at_level_df = _prepare_reports_at_level(
|
||||
nodes, edges, claims, level, max_context_tokens
|
||||
nodes, edges, claims, tokenizer, level, max_context_tokens
|
||||
)
|
||||
|
||||
communities_at_level_df.loc[:, schemas.COMMUNITY_LEVEL] = level
|
||||
@ -63,6 +64,7 @@ def _prepare_reports_at_level(
|
||||
node_df: pd.DataFrame,
|
||||
edge_df: pd.DataFrame,
|
||||
claim_df: pd.DataFrame | None,
|
||||
tokenizer: Tokenizer,
|
||||
level: int,
|
||||
max_context_tokens: int = 16_000,
|
||||
) -> pd.DataFrame:
|
||||
@ -181,6 +183,7 @@ def _prepare_reports_at_level(
|
||||
# Generate community-level context strings using vectorized batch processing
|
||||
return parallel_sort_context_batch(
|
||||
community_df,
|
||||
tokenizer=tokenizer,
|
||||
max_context_tokens=max_context_tokens,
|
||||
)
|
||||
|
||||
@ -189,6 +192,7 @@ def build_level_context(
|
||||
report_df: pd.DataFrame | None,
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
local_context_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
level: int,
|
||||
max_context_tokens: int,
|
||||
) -> pd.DataFrame:
|
||||
@ -219,11 +223,11 @@ def build_level_context(
|
||||
|
||||
if report_df is None or report_df.empty:
|
||||
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
|
||||
invalid_context_df, max_context_tokens
|
||||
invalid_context_df, tokenizer, max_context_tokens
|
||||
)
|
||||
invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
|
||||
:, schemas.CONTEXT_STRING
|
||||
].map(num_tokens)
|
||||
].map(tokenizer.num_tokens)
|
||||
invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
return union(valid_context_df, invalid_context_df)
|
||||
|
||||
@ -237,6 +241,7 @@ def build_level_context(
|
||||
invalid_context_df,
|
||||
sub_context_df,
|
||||
community_hierarchy_df,
|
||||
tokenizer,
|
||||
max_context_tokens,
|
||||
)
|
||||
|
||||
@ -244,11 +249,13 @@ def build_level_context(
|
||||
# this should be rare, but if it happens, we will just trim the local context to fit the limit
|
||||
remaining_df = _antijoin_reports(invalid_context_df, community_df)
|
||||
remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
|
||||
remaining_df, max_context_tokens
|
||||
remaining_df, tokenizer, max_context_tokens
|
||||
)
|
||||
|
||||
result = union(valid_context_df, community_df, remaining_df)
|
||||
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
|
||||
result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(
|
||||
tokenizer.num_tokens
|
||||
)
|
||||
|
||||
result[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
return result
|
||||
@ -269,19 +276,29 @@ def _antijoin_reports(df: pd.DataFrame, reports: pd.DataFrame) -> pd.DataFrame:
|
||||
return antijoin(df, reports, schemas.COMMUNITY_ID)
|
||||
|
||||
|
||||
def _sort_and_trim_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
|
||||
def _sort_and_trim_context(
|
||||
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> pd.Series:
|
||||
"""Sort and trim context to fit the limit."""
|
||||
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
|
||||
return transform_series(
|
||||
series, lambda x: sort_context(x, max_context_tokens=max_context_tokens)
|
||||
series,
|
||||
lambda x: sort_context(
|
||||
x, tokenizer=tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_mixed_context(df: pd.DataFrame, max_context_tokens: int) -> pd.Series:
|
||||
def _build_mixed_context(
|
||||
df: pd.DataFrame, tokenizer: Tokenizer, max_context_tokens: int
|
||||
) -> pd.Series:
|
||||
"""Sort and trim context to fit the limit."""
|
||||
series = cast("pd.Series", df[schemas.ALL_CONTEXT])
|
||||
return transform_series(
|
||||
series, lambda x: build_mixed_context(x, max_context_tokens=max_context_tokens)
|
||||
series,
|
||||
lambda x: build_mixed_context(
|
||||
x, tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -303,6 +320,7 @@ def _get_community_df(
|
||||
invalid_context_df: pd.DataFrame,
|
||||
sub_context_df: pd.DataFrame,
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int,
|
||||
) -> pd.DataFrame:
|
||||
"""Get community context for each community."""
|
||||
@ -338,7 +356,7 @@ def _get_community_df(
|
||||
.reset_index()
|
||||
)
|
||||
community_df[schemas.CONTEXT_STRING] = _build_mixed_context(
|
||||
community_df, max_context_tokens
|
||||
community_df, tokenizer, max_context_tokens
|
||||
)
|
||||
community_df[schemas.COMMUNITY_LEVEL] = level
|
||||
return community_df
|
||||
|
||||
@ -5,11 +5,12 @@
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.data_model.schemas as schemas
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
|
||||
def sort_context(
|
||||
local_context: list[dict],
|
||||
tokenizer: Tokenizer,
|
||||
sub_community_reports: list[dict] | None = None,
|
||||
max_context_tokens: int | None = None,
|
||||
node_name_column: str = schemas.TITLE,
|
||||
@ -112,7 +113,10 @@ def sort_context(
|
||||
new_context_string = _get_context_string(
|
||||
sorted_nodes, sorted_edges, sorted_claims, sub_community_reports
|
||||
)
|
||||
if max_context_tokens and num_tokens(new_context_string) > max_context_tokens:
|
||||
if (
|
||||
max_context_tokens
|
||||
and tokenizer.num_tokens(new_context_string) > max_context_tokens
|
||||
):
|
||||
break
|
||||
context_string = new_context_string
|
||||
|
||||
@ -122,7 +126,9 @@ def sort_context(
|
||||
)
|
||||
|
||||
|
||||
def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False):
|
||||
def parallel_sort_context_batch(
|
||||
community_df, tokenizer: Tokenizer, max_context_tokens, parallel=False
|
||||
):
|
||||
"""Calculate context using parallelization if enabled."""
|
||||
if parallel:
|
||||
# Use ThreadPoolExecutor for parallel execution
|
||||
@ -131,7 +137,9 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
|
||||
with ThreadPoolExecutor(max_workers=None) as executor:
|
||||
context_strings = list(
|
||||
executor.map(
|
||||
lambda x: sort_context(x, max_context_tokens=max_context_tokens),
|
||||
lambda x: sort_context(
|
||||
x, tokenizer, max_context_tokens=max_context_tokens
|
||||
),
|
||||
community_df[schemas.ALL_CONTEXT],
|
||||
)
|
||||
)
|
||||
@ -141,13 +149,13 @@ def parallel_sort_context_batch(community_df, max_context_tokens, parallel=False
|
||||
# Assign context strings directly to the DataFrame
|
||||
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda context_list: sort_context(
|
||||
context_list, max_context_tokens=max_context_tokens
|
||||
context_list, tokenizer, max_context_tokens=max_context_tokens
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate other columns
|
||||
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
|
||||
num_tokens
|
||||
tokenizer.num_tokens
|
||||
)
|
||||
community_df[schemas.CONTEXT_EXCEED_FLAG] = (
|
||||
community_df[schemas.CONTEXT_SIZE] > max_context_tokens
|
||||
|
||||
@ -16,7 +16,6 @@ from graphrag.index.operations.summarize_communities.typing import (
|
||||
Finding,
|
||||
StrategyConfig,
|
||||
)
|
||||
from graphrag.index.utils.rate_limiter import RateLimiter
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
|
||||
@ -51,8 +50,6 @@ async def _run_extractor(
|
||||
level: int,
|
||||
args: StrategyConfig,
|
||||
) -> CommunityReport | None:
|
||||
# RateLimiter
|
||||
rate_limiter = RateLimiter(rate=1, per=60)
|
||||
extractor = CommunityReportsExtractor(
|
||||
model,
|
||||
extraction_prompt=args.get("extraction_prompt", None),
|
||||
@ -63,7 +60,6 @@ async def _run_extractor(
|
||||
)
|
||||
|
||||
try:
|
||||
await rate_limiter.acquire()
|
||||
results = await extractor(input)
|
||||
report = results.structured_output
|
||||
if report is None:
|
||||
|
||||
@ -23,6 +23,7 @@ from graphrag.index.operations.summarize_communities.utils import (
|
||||
)
|
||||
from graphrag.index.utils.derive_from_rows import derive_from_rows
|
||||
from graphrag.logger.progress import progress_ticker
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -35,6 +36,7 @@ async def summarize_communities(
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict,
|
||||
tokenizer: Tokenizer,
|
||||
max_input_length: int,
|
||||
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||
num_threads: int = 4,
|
||||
@ -44,7 +46,6 @@ async def summarize_communities(
|
||||
tick = progress_ticker(callbacks.progress, len(local_contexts))
|
||||
strategy_exec = load_strategy(strategy["type"])
|
||||
strategy_config = {**strategy}
|
||||
|
||||
community_hierarchy = (
|
||||
communities.explode("children")
|
||||
.rename({"children": "sub_community"}, axis=1)
|
||||
@ -60,6 +61,7 @@ async def summarize_communities(
|
||||
community_hierarchy_df=community_hierarchy,
|
||||
local_context_df=local_contexts,
|
||||
level=level,
|
||||
tokenizer=tokenizer,
|
||||
max_context_tokens=max_input_length,
|
||||
)
|
||||
level_contexts.append(level_context)
|
||||
|
||||
@ -18,7 +18,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.prep_text
|
||||
from graphrag.index.operations.summarize_communities.text_unit_context.sort_context import (
|
||||
sort_context,
|
||||
)
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,6 +27,7 @@ def build_local_context(
|
||||
community_membership_df: pd.DataFrame,
|
||||
text_units_df: pd.DataFrame,
|
||||
node_df: pd.DataFrame,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int = 16000,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -69,10 +70,10 @@ def build_local_context(
|
||||
.reset_index()
|
||||
)
|
||||
context_df[schemas.CONTEXT_STRING] = context_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda x: sort_context(x)
|
||||
lambda x: sort_context(x, tokenizer)
|
||||
)
|
||||
context_df[schemas.CONTEXT_SIZE] = context_df[schemas.CONTEXT_STRING].apply(
|
||||
lambda x: num_tokens(x)
|
||||
lambda x: tokenizer.num_tokens(x)
|
||||
)
|
||||
context_df[schemas.CONTEXT_EXCEED_FLAG] = context_df[schemas.CONTEXT_SIZE].apply(
|
||||
lambda x: x > max_context_tokens
|
||||
@ -86,6 +87,7 @@ def build_level_context(
|
||||
community_hierarchy_df: pd.DataFrame,
|
||||
local_context_df: pd.DataFrame,
|
||||
level: int,
|
||||
tokenizer: Tokenizer,
|
||||
max_context_tokens: int = 16000,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -116,10 +118,12 @@ def build_level_context(
|
||||
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_STRING]] = invalid_context_df[
|
||||
schemas.ALL_CONTEXT
|
||||
].apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
|
||||
].apply(
|
||||
lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens)
|
||||
)
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_SIZE]] = invalid_context_df[
|
||||
schemas.CONTEXT_STRING
|
||||
].apply(lambda x: num_tokens(x))
|
||||
].apply(lambda x: tokenizer.num_tokens(x))
|
||||
invalid_context_df.loc[:, [schemas.CONTEXT_EXCEED_FLAG]] = False
|
||||
|
||||
return pd.concat([valid_context_df, invalid_context_df])
|
||||
@ -199,10 +203,10 @@ def build_level_context(
|
||||
.reset_index()
|
||||
)
|
||||
community_df[schemas.CONTEXT_STRING] = community_df[schemas.ALL_CONTEXT].apply(
|
||||
lambda x: build_mixed_context(x, max_context_tokens)
|
||||
lambda x: build_mixed_context(x, tokenizer, max_context_tokens)
|
||||
)
|
||||
community_df[schemas.CONTEXT_SIZE] = community_df[schemas.CONTEXT_STRING].apply(
|
||||
lambda x: num_tokens(x)
|
||||
lambda x: tokenizer.num_tokens(x)
|
||||
)
|
||||
community_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
community_df[schemas.COMMUNITY_LEVEL] = level
|
||||
@ -220,10 +224,10 @@ def build_level_context(
|
||||
)
|
||||
remaining_df[schemas.CONTEXT_STRING] = cast(
|
||||
"pd.DataFrame", remaining_df[schemas.ALL_CONTEXT]
|
||||
).apply(lambda x: sort_context(x, max_context_tokens=max_context_tokens))
|
||||
).apply(lambda x: sort_context(x, tokenizer, max_context_tokens=max_context_tokens))
|
||||
remaining_df[schemas.CONTEXT_SIZE] = cast(
|
||||
"pd.DataFrame", remaining_df[schemas.CONTEXT_STRING]
|
||||
).apply(lambda x: num_tokens(x))
|
||||
).apply(lambda x: tokenizer.num_tokens(x))
|
||||
remaining_df[schemas.CONTEXT_EXCEED_FLAG] = False
|
||||
|
||||
return cast(
|
||||
|
||||
@ -8,7 +8,7 @@ import logging
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.data_model.schemas as schemas
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,6 +57,7 @@ def get_context_string(
|
||||
|
||||
def sort_context(
|
||||
local_context: list[dict],
|
||||
tokenizer: Tokenizer,
|
||||
sub_community_reports: list[dict] | None = None,
|
||||
max_context_tokens: int | None = None,
|
||||
) -> str:
|
||||
@ -73,7 +74,7 @@ def sort_context(
|
||||
new_context_string = get_context_string(
|
||||
current_text_units, sub_community_reports
|
||||
)
|
||||
if num_tokens(new_context_string) > max_context_tokens:
|
||||
if tokenizer.num_tokens(new_context_string) > max_context_tokens:
|
||||
break
|
||||
|
||||
context_string = new_context_string
|
||||
|
||||
@ -2,17 +2,3 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Root package for description summarization."""
|
||||
|
||||
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
|
||||
summarize_descriptions,
|
||||
)
|
||||
from graphrag.index.operations.summarize_descriptions.typing import (
|
||||
SummarizationStrategy,
|
||||
SummarizeStrategyType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SummarizationStrategy",
|
||||
"SummarizeStrategyType",
|
||||
"summarize_descriptions",
|
||||
]
|
||||
|
||||
@ -7,9 +7,9 @@ import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.index.typing.error_handler import ErrorHandlerFn
|
||||
from graphrag.index.utils.tokens import num_tokens_from_string
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
# these tokens are used in the prompt
|
||||
ENTITY_NAME_KEY = "entity_name"
|
||||
@ -45,7 +45,7 @@ class SummarizeExtractor:
|
||||
"""Init method definition."""
|
||||
# TODO: streamline construction
|
||||
self._model = model_invoker
|
||||
|
||||
self._tokenizer = get_tokenizer(model_invoker.config)
|
||||
self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
self._max_summary_length = max_summary_length
|
||||
@ -85,14 +85,14 @@ class SummarizeExtractor:
|
||||
descriptions = sorted(descriptions)
|
||||
|
||||
# Iterate over descriptions, adding all until the max input tokens is reached
|
||||
usable_tokens = self._max_input_tokens - num_tokens_from_string(
|
||||
usable_tokens = self._max_input_tokens - self._tokenizer.num_tokens(
|
||||
self._summarization_prompt
|
||||
)
|
||||
descriptions_collected = []
|
||||
result = ""
|
||||
|
||||
for i, description in enumerate(descriptions):
|
||||
usable_tokens -= num_tokens_from_string(description)
|
||||
usable_tokens -= self._tokenizer.num_tokens(description)
|
||||
descriptions_collected.append(description)
|
||||
|
||||
# If buffer is full, or all descriptions have been added, summarize
|
||||
@ -109,8 +109,8 @@ class SummarizeExtractor:
|
||||
descriptions_collected = [result]
|
||||
usable_tokens = (
|
||||
self._max_input_tokens
|
||||
- num_tokens_from_string(self._summarization_prompt)
|
||||
- num_tokens_from_string(result)
|
||||
- self._tokenizer.num_tokens(self._summarization_prompt)
|
||||
- self._tokenizer.num_tokens(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -9,6 +9,9 @@ import re
|
||||
import time
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
@ -28,6 +31,8 @@ async def run_pipeline(
|
||||
config: GraphRagConfig,
|
||||
callbacks: WorkflowCallbacks,
|
||||
is_update_run: bool = False,
|
||||
additional_context: dict[str, Any] | None = None,
|
||||
input_documents: pd.DataFrame | None = None,
|
||||
) -> AsyncIterable[PipelineRunResult]:
|
||||
"""Run all workflows using a simplified pipeline."""
|
||||
root_dir = config.root_dir
|
||||
@ -40,6 +45,9 @@ async def run_pipeline(
|
||||
state_json = await output_storage.get("context.json")
|
||||
state = json.loads(state_json) if state_json else {}
|
||||
|
||||
if additional_context:
|
||||
state.setdefault("additional_context", {}).update(additional_context)
|
||||
|
||||
if is_update_run:
|
||||
logger.info("Running incremental indexing.")
|
||||
|
||||
@ -55,6 +63,11 @@ async def run_pipeline(
|
||||
|
||||
state["update_timestamp"] = update_timestamp
|
||||
|
||||
# if the user passes in a df directly, write directly to storage so we can skip finding/parsing later
|
||||
if input_documents is not None:
|
||||
await write_table_to_storage(input_documents, "documents", delta_storage)
|
||||
pipeline.remove("load_update_documents")
|
||||
|
||||
context = create_run_context(
|
||||
input_storage=input_storage,
|
||||
output_storage=delta_storage,
|
||||
@ -67,6 +80,11 @@ async def run_pipeline(
|
||||
else:
|
||||
logger.info("Running standard indexing.")
|
||||
|
||||
# if the user passes in a df directly, write directly to storage so we can skip finding/parsing later
|
||||
if input_documents is not None:
|
||||
await write_table_to_storage(input_documents, "documents", output_storage)
|
||||
pipeline.remove("load_input_documents")
|
||||
|
||||
context = create_run_context(
|
||||
input_storage=input_storage,
|
||||
output_storage=output_storage,
|
||||
@ -126,9 +144,17 @@ async def _dump_json(context: PipelineRunContext) -> None:
|
||||
await context.output_storage.set(
|
||||
"stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
|
||||
)
|
||||
await context.output_storage.set(
|
||||
"context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
|
||||
)
|
||||
# Dump context state, excluding additional_context
|
||||
temp_context = context.state.pop(
|
||||
"additional_context", None
|
||||
) # Remove reference only, as object size is uncertain
|
||||
try:
|
||||
state_blob = json.dumps(context.state, indent=4, ensure_ascii=False)
|
||||
finally:
|
||||
if temp_context:
|
||||
context.state["additional_context"] = temp_context
|
||||
|
||||
await context.output_storage.set("context.json", state_blob)
|
||||
|
||||
|
||||
async def _copy_previous_output(
|
||||
|
||||
@ -5,16 +5,16 @@
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Collection, Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.index.operations.chunk_text.typing import TextChunk
|
||||
from graphrag.logger.progress import ProgressTicker
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.tokenizer.tokenizer import Tokenizer
|
||||
|
||||
EncodedText = list[int]
|
||||
DecodeFn = Callable[[EncodedText], str]
|
||||
@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
"""Tokenizer data class."""
|
||||
class TokenChunkerOptions:
|
||||
"""TokenChunkerOptions data class."""
|
||||
|
||||
chunk_overlap: int
|
||||
"""Overlap in tokens between chunks"""
|
||||
@ -83,44 +83,18 @@ class NoopTextSplitter(TextSplitter):
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
"""Token text splitter class definition."""
|
||||
|
||||
_allowed_special: Literal["all"] | set[str]
|
||||
_disallowed_special: Literal["all"] | Collection[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoding_name: str = defs.ENCODING_MODEL,
|
||||
model_name: str | None = None,
|
||||
allowed_special: Literal["all"] | set[str] | None = None,
|
||||
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||
tokenizer: Tokenizer | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Init method definition."""
|
||||
super().__init__(**kwargs)
|
||||
if model_name is not None:
|
||||
try:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
logger.exception(
|
||||
"Model %s not found, using %s", model_name, encoding_name
|
||||
)
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
self._tokenizer = enc
|
||||
self._allowed_special = allowed_special or set()
|
||||
self._disallowed_special = disallowed_special
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
"""Encode the given text into an int-vector."""
|
||||
return self._tokenizer.encode(
|
||||
text,
|
||||
allowed_special=self._allowed_special,
|
||||
disallowed_special=self._disallowed_special,
|
||||
)
|
||||
self._tokenizer = tokenizer or get_tokenizer()
|
||||
|
||||
def num_tokens(self, text: str) -> int:
|
||||
"""Return the number of tokens in a string."""
|
||||
return len(self.encode(text))
|
||||
return self._tokenizer.num_tokens(text)
|
||||
|
||||
def split_text(self, text: str | list[str]) -> list[str]:
|
||||
"""Split text method."""
|
||||
@ -132,17 +106,17 @@ class TokenTextSplitter(TextSplitter):
|
||||
msg = f"Attempting to split a non-string value, actual is {type(text)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
token_chunker_options = TokenChunkerOptions(
|
||||
chunk_overlap=self._chunk_overlap,
|
||||
tokens_per_chunk=self._chunk_size,
|
||||
decode=self._tokenizer.decode,
|
||||
encode=lambda text: self.encode(text),
|
||||
encode=self._tokenizer.encode,
|
||||
)
|
||||
|
||||
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
return split_single_text_on_tokens(text=text, tokenizer=token_chunker_options)
|
||||
|
||||
|
||||
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
def split_single_text_on_tokens(text: str, tokenizer: TokenChunkerOptions) -> list[str]:
|
||||
"""Split a single text and return chunks using the tokenizer."""
|
||||
result = []
|
||||
input_ids = tokenizer.encode(text)
|
||||
@ -166,7 +140,7 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
|
||||
# So we could have better control over the chunking process
|
||||
def split_multiple_texts_on_tokens(
|
||||
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
|
||||
texts: list[str], tokenizer: TokenChunkerOptions, tick: ProgressTicker
|
||||
) -> list[TextChunk]:
|
||||
"""Split multiple texts and return chunks with metadata using the tokenizer."""
|
||||
result = []
|
||||
|
||||
@ -21,3 +21,7 @@ class Pipeline:
|
||||
def names(self) -> list[str]:
|
||||
"""Return the names of the workflows in the pipeline."""
|
||||
return [name for name, _ in self.workflows]
|
||||
|
||||
def remove(self, name: str) -> None:
|
||||
"""Remove a workflow from the pipeline by name."""
|
||||
self.workflows = [w for w in self.workflows if w[0] != name]
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Rate limiter utility."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
The original TpmRpmLLMLimiter strategy did not account for minute-based rate limiting when scheduled.
|
||||
|
||||
The RateLimiter was introduced to ensure that the CommunityReportsExtractor could be scheduled to adhere to rate configurations on a per-minute basis.
|
||||
"""
|
||||
|
||||
# TODO: RateLimiter scheduled: using asyncio for async_mode
|
||||
|
||||
def __init__(self, rate: int, per: int):
|
||||
self.rate = rate
|
||||
self.per = per
|
||||
self.allowance = rate
|
||||
self.last_check = time.monotonic()
|
||||
|
||||
async def acquire(self):
|
||||
"""Acquire a token from the rate limiter."""
|
||||
current = time.monotonic()
|
||||
elapsed = current - self.last_check
|
||||
self.last_check = current
|
||||
self.allowance += elapsed * (self.rate / self.per)
|
||||
|
||||
if self.allowance > self.rate:
|
||||
self.allowance = self.rate
|
||||
|
||||
if self.allowance < 1.0:
|
||||
sleep_time = (1.0 - self.allowance) * (self.per / self.rate)
|
||||
await asyncio.sleep(sleep_time)
|
||||
self.allowance = 0.0
|
||||
else:
|
||||
self.allowance -= 1.0
|
||||
@ -1,44 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Utilities for working with tokens."""
|
||||
|
||||
import logging
|
||||
|
||||
import tiktoken
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
|
||||
DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def num_tokens_from_string(
|
||||
string: str, model: str | None = None, encoding_name: str | None = None
|
||||
) -> int:
|
||||
"""Return the number of tokens in a text string."""
|
||||
if model is not None:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
msg = f"Failed to get encoding for {model} when getting num_tokens_from_string. Fall back to default encoding {DEFAULT_ENCODING_NAME}"
|
||||
logger.warning(msg)
|
||||
encoding = tiktoken.get_encoding(DEFAULT_ENCODING_NAME)
|
||||
else:
|
||||
encoding = tiktoken.get_encoding(encoding_name or DEFAULT_ENCODING_NAME)
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
def string_from_tokens(
|
||||
tokens: list[int], model: str | None = None, encoding_name: str | None = None
|
||||
) -> str:
|
||||
"""Return a text string from a list of tokens."""
|
||||
if model is not None:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
elif encoding_name is not None:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
else:
|
||||
msg = "Either model or encoding_name must be specified."
|
||||
raise ValueError(msg)
|
||||
return encoding.decode(tokens)
|
||||
@ -15,42 +15,39 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_config_names(parameters: GraphRagConfig) -> None:
|
||||
"""Validate config file for LLM deployment name typos."""
|
||||
# Validate Chat LLM configs
|
||||
# TODO: Replace default_chat_model with a way to select the model
|
||||
default_llm_settings = parameters.get_language_model_config("default_chat_model")
|
||||
|
||||
llm = ModelManager().register_chat(
|
||||
name="test-llm",
|
||||
model_type=default_llm_settings.type,
|
||||
config=default_llm_settings,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(llm.achat("This is an LLM connectivity test. Say Hello World"))
|
||||
logger.info("LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"LLM configuration error detected. Exiting...\n{e}") # noqa
|
||||
sys.exit(1)
|
||||
|
||||
# Validate Embeddings LLM configs
|
||||
embedding_llm_settings = parameters.get_language_model_config(
|
||||
parameters.embed_text.model_id
|
||||
)
|
||||
|
||||
embed_llm = ModelManager().register_embedding(
|
||||
name="test-embed-llm",
|
||||
model_type=embedding_llm_settings.type,
|
||||
config=embedding_llm_settings,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(embed_llm.aembed_batch(["This is an LLM Embedding Test String"]))
|
||||
logger.info("Embedding LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Embedding LLM configuration error detected. Exiting...\n{e}") # noqa
|
||||
sys.exit(1)
|
||||
"""Validate config file for model deployment name typos, by running a quick test message for each."""
|
||||
for id, config in parameters.models.items():
|
||||
if config.type in ["chat", "azure_openai", "openai"]:
|
||||
llm = ModelManager().register_chat(
|
||||
name="test-llm",
|
||||
model_type=config.type,
|
||||
config=config,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
try:
|
||||
asyncio.run(
|
||||
llm.achat("This is an LLM connectivity test. Say Hello World")
|
||||
)
|
||||
logger.info("LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"LLM configuration error detected.\n{e}") # noqa
|
||||
print(f"Failed to validate language model ({id}) params", e) # noqa: T201
|
||||
sys.exit(1)
|
||||
elif config.type in ["embedding", "azure_openai_embedding", "openai_embedding"]:
|
||||
embed_llm = ModelManager().register_embedding(
|
||||
name="test-embed-llm",
|
||||
model_type=config.type,
|
||||
config=config,
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
try:
|
||||
asyncio.run(
|
||||
embed_llm.aembed_batch(["This is an LLM Embedding Test String"])
|
||||
)
|
||||
logger.info("Embedding LLM Config Params Validated")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Embedding configuration error detected.\n{e}") # noqa
|
||||
print(f"Failed to validate embedding model ({id}) params", e) # noqa: T201
|
||||
sys.exit(1)
|
||||
|
||||
@ -17,7 +17,6 @@ from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.progress import Progress
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -69,8 +68,6 @@ def create_base_text_units(
|
||||
zip(*[sort[col] for col in ["id", "text"]], strict=True)
|
||||
)
|
||||
|
||||
callbacks.progress(Progress(percent=0))
|
||||
|
||||
agg_dict = {"text_with_ids": list}
|
||||
if "metadata" in documents:
|
||||
agg_dict["metadata"] = "first" # type: ignore
|
||||
|
||||
@ -13,6 +13,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.operations.finalize_community_reports import (
|
||||
finalize_community_reports,
|
||||
)
|
||||
@ -28,6 +29,7 @@ from graphrag.index.operations.summarize_communities.summarize_communities impor
|
||||
)
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.utils.storage import (
|
||||
load_table_from_storage,
|
||||
storage_has_table,
|
||||
@ -102,6 +104,9 @@ async def create_community_reports(
|
||||
|
||||
summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"]
|
||||
|
||||
model_config = LanguageModelConfig(**summarization_strategy["llm"])
|
||||
tokenizer = get_tokenizer(model_config)
|
||||
|
||||
max_input_length = summarization_strategy.get(
|
||||
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
|
||||
)
|
||||
@ -110,6 +115,7 @@ async def create_community_reports(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
tokenizer,
|
||||
callbacks,
|
||||
max_input_length,
|
||||
)
|
||||
@ -122,6 +128,7 @@ async def create_community_reports(
|
||||
callbacks,
|
||||
cache,
|
||||
summarization_strategy,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=max_input_length,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
|
||||
@ -12,6 +12,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.operations.finalize_community_reports import (
|
||||
finalize_community_reports,
|
||||
)
|
||||
@ -27,6 +28,7 @@ from graphrag.index.operations.summarize_communities.text_unit_context.context_b
|
||||
)
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -88,8 +90,11 @@ async def create_community_reports_text(
|
||||
"max_input_length", graphrag_config_defaults.community_reports.max_input_length
|
||||
)
|
||||
|
||||
model_config = LanguageModelConfig(**summarization_strategy["llm"])
|
||||
tokenizer = get_tokenizer(model_config)
|
||||
|
||||
local_contexts = build_local_context(
|
||||
communities, text_units, nodes, max_input_length
|
||||
communities, text_units, nodes, tokenizer, max_input_length
|
||||
)
|
||||
|
||||
community_reports = await summarize_communities(
|
||||
@ -100,6 +105,7 @@ async def create_community_reports_text(
|
||||
callbacks,
|
||||
cache,
|
||||
summarization_strategy,
|
||||
tokenizer=tokenizer,
|
||||
max_input_length=max_input_length,
|
||||
async_mode=async_mode,
|
||||
num_threads=num_threads,
|
||||
|
||||
@ -69,7 +69,7 @@ def create_final_documents(
|
||||
).reset_index(drop=True)
|
||||
|
||||
rejoined["id"] = rejoined["id"].astype(str)
|
||||
rejoined["human_readable_id"] = rejoined.index + 1
|
||||
rejoined["human_readable_id"] = rejoined.index
|
||||
|
||||
if "metadata" not in rejoined.columns:
|
||||
rejoined["metadata"] = pd.Series(dtype="object")
|
||||
|
||||
@ -60,7 +60,7 @@ def create_final_text_units(
|
||||
) -> pd.DataFrame:
|
||||
"""All the steps to transform the text units."""
|
||||
selected = text_units.loc[:, ["id", "text", "document_ids", "n_tokens"]]
|
||||
selected["human_readable_id"] = selected.index + 1
|
||||
selected["human_readable_id"] = selected.index
|
||||
|
||||
entity_join = _entities(final_entities)
|
||||
relationship_join = _relationships(final_relationships)
|
||||
|
||||
@ -88,6 +88,6 @@ async def extract_covariates(
|
||||
)
|
||||
text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global
|
||||
covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4()))
|
||||
covariates["human_readable_id"] = covariates.index + 1
|
||||
covariates["human_readable_id"] = covariates.index
|
||||
|
||||
return covariates.loc[:, COVARIATES_FINAL_COLUMNS]
|
||||
|
||||
@ -15,7 +15,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.extract_graph.extract_graph import (
|
||||
extract_graph as extractor,
|
||||
)
|
||||
from graphrag.index.operations.summarize_descriptions import (
|
||||
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
|
||||
summarize_descriptions,
|
||||
)
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
|
||||
@ -61,6 +61,7 @@ async def extract_graph_nlp(
|
||||
text_analyzer=text_analyzer,
|
||||
normalize_edge_weights=extraction_config.normalize_edge_weights,
|
||||
num_threads=extraction_config.concurrent_requests,
|
||||
async_mode=extraction_config.async_mode,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
|
||||
@ -21,12 +21,11 @@ from graphrag.config.embeddings import (
|
||||
)
|
||||
from graphrag.config.get_embedding_settings import get_embedding_settings
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.embed_text import embed_text
|
||||
from graphrag.index.operations.embed_text.embed_text import embed_text
|
||||
from graphrag.index.typing.context import PipelineRunContext
|
||||
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||||
from graphrag.utils.storage import (
|
||||
load_table_from_storage,
|
||||
storage_has_table,
|
||||
write_table_to_storage,
|
||||
)
|
||||
|
||||
@ -39,27 +38,35 @@ async def run_workflow(
|
||||
) -> WorkflowFunctionOutput:
|
||||
"""All the steps to transform community reports."""
|
||||
logger.info("Workflow started: generate_text_embeddings")
|
||||
embedded_fields = config.embed_text.names
|
||||
logger.info("Embedding the following fields: %s", embedded_fields)
|
||||
documents = None
|
||||
relationships = None
|
||||
text_units = None
|
||||
entities = None
|
||||
community_reports = None
|
||||
if await storage_has_table("documents", context.output_storage):
|
||||
if document_text_embedding in embedded_fields:
|
||||
documents = await load_table_from_storage("documents", context.output_storage)
|
||||
if await storage_has_table("relationships", context.output_storage):
|
||||
if relationship_description_embedding in embedded_fields:
|
||||
relationships = await load_table_from_storage(
|
||||
"relationships", context.output_storage
|
||||
)
|
||||
if await storage_has_table("text_units", context.output_storage):
|
||||
if text_unit_text_embedding in embedded_fields:
|
||||
text_units = await load_table_from_storage("text_units", context.output_storage)
|
||||
if await storage_has_table("entities", context.output_storage):
|
||||
if (
|
||||
entity_title_embedding in embedded_fields
|
||||
or entity_description_embedding in embedded_fields
|
||||
):
|
||||
entities = await load_table_from_storage("entities", context.output_storage)
|
||||
if await storage_has_table("community_reports", context.output_storage):
|
||||
if (
|
||||
community_title_embedding in embedded_fields
|
||||
or community_summary_embedding in embedded_fields
|
||||
or community_full_content_embedding in embedded_fields
|
||||
):
|
||||
community_reports = await load_table_from_storage(
|
||||
"community_reports", context.output_storage
|
||||
)
|
||||
|
||||
embedded_fields = config.embed_text.names
|
||||
text_embed = get_embedding_settings(config)
|
||||
|
||||
output = await generate_text_embeddings(
|
||||
|
||||
@ -7,13 +7,17 @@ from collections.abc import Callable
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from graphrag.config.enums import ModelType
|
||||
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
|
||||
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
|
||||
from graphrag.language_model.providers.fnllm.models import (
|
||||
AzureOpenAIChatFNLLM,
|
||||
AzureOpenAIEmbeddingFNLLM,
|
||||
OpenAIChatFNLLM,
|
||||
OpenAIEmbeddingFNLLM,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel
|
||||
from graphrag.language_model.providers.litellm.embedding_model import (
|
||||
LitellmEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
@ -100,15 +104,20 @@ class ModelFactory:
|
||||
|
||||
# --- Register default implementations ---
|
||||
ModelFactory.register_chat(
|
||||
ModelType.AzureOpenAIChat, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
|
||||
ModelType.AzureOpenAIChat.value, lambda **kwargs: AzureOpenAIChatFNLLM(**kwargs)
|
||||
)
|
||||
ModelFactory.register_chat(
|
||||
ModelType.OpenAIChat, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
|
||||
ModelType.OpenAIChat.value, lambda **kwargs: OpenAIChatFNLLM(**kwargs)
|
||||
)
|
||||
ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs))
|
||||
|
||||
ModelFactory.register_embedding(
|
||||
ModelType.AzureOpenAIEmbedding, lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs)
|
||||
ModelType.AzureOpenAIEmbedding.value,
|
||||
lambda **kwargs: AzureOpenAIEmbeddingFNLLM(**kwargs),
|
||||
)
|
||||
ModelFactory.register_embedding(
|
||||
ModelType.OpenAIEmbedding, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
|
||||
ModelType.OpenAIEmbedding.value, lambda **kwargs: OpenAIEmbeddingFNLLM(**kwargs)
|
||||
)
|
||||
ModelFactory.register_embedding(
|
||||
ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs)
|
||||
)
|
||||
|
||||
@ -16,7 +16,7 @@ from typing_extensions import Self
|
||||
from graphrag.language_model.factory import ModelFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.language_model.protocol import ChatModel, EmbeddingModel
|
||||
from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
@ -2,7 +2,3 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Base protocol definitions for LLMs."""
|
||||
|
||||
from .base import ChatModel, EmbeddingModel
|
||||
|
||||
__all__ = ["ChatModel", "EmbeddingModel"]
|
||||
|
||||
4
graphrag/language_model/providers/litellm/__init__.py
Normal file
4
graphrag/language_model/providers/litellm/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""GraphRAG LiteLLM module. Provides LiteLLM-based implementations of chat and embedding models."""
|
||||
414
graphrag/language_model/providers/litellm/chat_model.py
Normal file
414
graphrag/language_model/providers/litellm/chat_model.py
Normal file
@ -0,0 +1,414 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Chat model implementation using Litellm."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import litellm
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from litellm import (
|
||||
CustomStreamWrapper,
|
||||
ModelResponse, # type: ignore
|
||||
acompletion,
|
||||
completion,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
|
||||
from graphrag.config.enums import AuthType
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
|
||||
with_cache,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
|
||||
with_logging,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
|
||||
with_rate_limiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
|
||||
with_retries,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.types import (
|
||||
AFixedModelCompletion,
|
||||
FixedModelCompletion,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.language_model.response.base import ModelResponse as MR # noqa: N817
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
|
||||
def _create_base_completions(
|
||||
model_config: "LanguageModelConfig",
|
||||
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
|
||||
"""Wrap the base litellm completion function with the model configuration.
|
||||
|
||||
Args
|
||||
----
|
||||
model_config: The configuration for the language model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing the synchronous and asynchronous completion functions.
|
||||
"""
|
||||
model_provider = model_config.model_provider
|
||||
model = model_config.deployment_name or model_config.model
|
||||
|
||||
base_args: dict[str, Any] = {
|
||||
"drop_params": True, # LiteLLM drop unsupported params for selected model.
|
||||
"model": f"{model_provider}/{model}",
|
||||
"timeout": model_config.request_timeout,
|
||||
"top_p": model_config.top_p,
|
||||
"n": model_config.n,
|
||||
"temperature": model_config.temperature,
|
||||
"frequency_penalty": model_config.frequency_penalty,
|
||||
"presence_penalty": model_config.presence_penalty,
|
||||
"api_base": model_config.api_base,
|
||||
"api_version": model_config.api_version,
|
||||
"api_key": model_config.api_key,
|
||||
"organization": model_config.organization,
|
||||
"proxy": model_config.proxy,
|
||||
"audience": model_config.audience,
|
||||
"max_tokens": model_config.max_tokens,
|
||||
"max_completion_tokens": model_config.max_completion_tokens,
|
||||
"reasoning_effort": model_config.reasoning_effort,
|
||||
}
|
||||
|
||||
if model_config.auth_type == AuthType.AzureManagedIdentity:
|
||||
if model_config.model_provider != "azure":
|
||||
msg = "Azure Managed Identity authentication is only supported for Azure models."
|
||||
raise ValueError(msg)
|
||||
|
||||
base_args["azure_scope"] = base_args.pop("audience")
|
||||
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
|
||||
DefaultAzureCredential(),
|
||||
model_config.audience or COGNITIVE_SERVICES_AUDIENCE,
|
||||
)
|
||||
|
||||
def _base_completion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
|
||||
new_args = {**base_args, **kwargs}
|
||||
|
||||
if "name" in new_args:
|
||||
new_args.pop("name")
|
||||
|
||||
return completion(**new_args)
|
||||
|
||||
async def _base_acompletion(**kwargs: Any) -> ModelResponse | CustomStreamWrapper:
|
||||
new_args = {**base_args, **kwargs}
|
||||
|
||||
if "name" in new_args:
|
||||
new_args.pop("name")
|
||||
|
||||
return await acompletion(**new_args)
|
||||
|
||||
return (_base_completion, _base_acompletion)
|
||||
|
||||
|
||||
def _create_completions(
|
||||
model_config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None",
|
||||
cache_key_prefix: str,
|
||||
) -> tuple[FixedModelCompletion, AFixedModelCompletion]:
|
||||
"""Wrap the base litellm completion function with the model configuration and additional features.
|
||||
|
||||
Wrap the base litellm completion function with instance variables based on the model configuration.
|
||||
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
|
||||
|
||||
Final function composition order:
|
||||
- Logging(Cache(Retries(RateLimiter(ModelCompletion()))))
|
||||
|
||||
Args
|
||||
----
|
||||
model_config: The configuration for the language model.
|
||||
cache: Optional cache for storing responses.
|
||||
cache_key_prefix: Prefix for cache keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing the synchronous and asynchronous completion functions.
|
||||
|
||||
"""
|
||||
completion, acompletion = _create_base_completions(model_config)
|
||||
|
||||
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
|
||||
# LiteLLM does not support "auto", so we have to check those values here.
|
||||
# For v3 release, force rpm/tpm to be int and remove the type checks below
|
||||
# and just check if rate_limit_strategy is enabled.
|
||||
if model_config.rate_limit_strategy is not None:
|
||||
rpm = (
|
||||
model_config.requests_per_minute
|
||||
if type(model_config.requests_per_minute) is int
|
||||
else None
|
||||
)
|
||||
tpm = (
|
||||
model_config.tokens_per_minute
|
||||
if type(model_config.tokens_per_minute) is int
|
||||
else None
|
||||
)
|
||||
if rpm is not None or tpm is not None:
|
||||
completion, acompletion = with_rate_limiter(
|
||||
sync_fn=completion,
|
||||
async_fn=acompletion,
|
||||
model_config=model_config,
|
||||
rpm=rpm,
|
||||
tpm=tpm,
|
||||
)
|
||||
|
||||
if model_config.retry_strategy != "none":
|
||||
completion, acompletion = with_retries(
|
||||
sync_fn=completion,
|
||||
async_fn=acompletion,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
completion, acompletion = with_cache(
|
||||
sync_fn=completion,
|
||||
async_fn=acompletion,
|
||||
model_config=model_config,
|
||||
cache=cache,
|
||||
request_type="chat",
|
||||
cache_key_prefix=cache_key_prefix,
|
||||
)
|
||||
|
||||
completion, acompletion = with_logging(
|
||||
sync_fn=completion,
|
||||
async_fn=acompletion,
|
||||
)
|
||||
|
||||
return (completion, acompletion)
|
||||
|
||||
|
||||
class LitellmModelOutput(BaseModel):
|
||||
"""A model representing the output from a language model."""
|
||||
|
||||
content: str = Field(description="The generated text content")
|
||||
full_response: None = Field(
|
||||
default=None, description="The full response from the model, if available"
|
||||
)
|
||||
|
||||
|
||||
class LitellmModelResponse(BaseModel):
|
||||
"""A model representing the response from a language model."""
|
||||
|
||||
output: LitellmModelOutput = Field(description="The output from the model")
|
||||
parsed_response: BaseModel | None = Field(
|
||||
default=None, description="Parsed response from the model"
|
||||
)
|
||||
history: list = Field(
|
||||
default_factory=list,
|
||||
description="Conversation history including the prompt and response",
|
||||
)
|
||||
|
||||
|
||||
class LitellmChatModel:
|
||||
"""LiteLLM-based Chat Model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None" = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.name = name
|
||||
self.config = config
|
||||
self.cache = cache.child(self.name) if cache else None
|
||||
self.completion, self.acompletion = _create_completions(
|
||||
config, self.cache, "chat"
|
||||
)
|
||||
|
||||
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Get model arguments supported by litellm."""
|
||||
args_to_include = [
|
||||
"name",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"audio",
|
||||
"logit_bias",
|
||||
"metadata",
|
||||
"user",
|
||||
"response_format",
|
||||
"seed",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"parallel_tool_calls",
|
||||
"web_search_options",
|
||||
"extra_headers",
|
||||
"functions",
|
||||
"function_call",
|
||||
"thinking",
|
||||
]
|
||||
new_args = {k: v for k, v in kwargs.items() if k in args_to_include}
|
||||
|
||||
# If using JSON, check if response_format should be a Pydantic model or just a general JSON object
|
||||
if kwargs.get("json"):
|
||||
new_args["response_format"] = {"type": "json_object"}
|
||||
|
||||
if (
|
||||
"json_model" in kwargs
|
||||
and inspect.isclass(kwargs["json_model"])
|
||||
and issubclass(kwargs["json_model"], BaseModel)
|
||||
):
|
||||
new_args["response_format"] = kwargs["json_model"]
|
||||
|
||||
return new_args
|
||||
|
||||
async def achat(
|
||||
self, prompt: str, history: list | None = None, **kwargs: Any
|
||||
) -> "MR":
|
||||
"""
|
||||
Generate a response for the given prompt and history.
|
||||
|
||||
Args
|
||||
----
|
||||
prompt: The prompt to generate a response for.
|
||||
history: Optional chat history.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
LitellmModelResponse: The generated model response.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
messages: list[dict[str, str]] = history or []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await self.acompletion(messages=messages, stream=False, **new_kwargs) # type: ignore
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.choices[0].message.content or "", # type: ignore
|
||||
})
|
||||
|
||||
parsed_response: BaseModel | None = None
|
||||
if "response_format" in new_kwargs:
|
||||
parsed_dict: dict[str, Any] = json.loads(
|
||||
response.choices[0].message.content or "{}" # type: ignore
|
||||
)
|
||||
parsed_response = parsed_dict # type: ignore
|
||||
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
|
||||
new_kwargs["response_format"], BaseModel
|
||||
):
|
||||
# If response_format is a pydantic model, instantiate it
|
||||
model_initializer = cast(
|
||||
"type[BaseModel]", new_kwargs["response_format"]
|
||||
)
|
||||
parsed_response = model_initializer(**parsed_dict)
|
||||
|
||||
return LitellmModelResponse(
|
||||
output=LitellmModelOutput(
|
||||
content=response.choices[0].message.content or "" # type: ignore
|
||||
),
|
||||
parsed_response=parsed_response,
|
||||
history=messages,
|
||||
)
|
||||
|
||||
async def achat_stream(
|
||||
self, prompt: str, history: list | None = None, **kwargs: Any
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Generate a response for the given prompt and history.
|
||||
|
||||
Args
|
||||
----
|
||||
prompt: The prompt to generate a response for.
|
||||
history: Optional chat history.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
AsyncGenerator[str, None]: The generated response as a stream of strings.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
messages: list[dict[str, str]] = history or []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = await self.acompletion(messages=messages, stream=True, **new_kwargs) # type: ignore
|
||||
|
||||
async for chunk in response: # type: ignore
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
def chat(self, prompt: str, history: list | None = None, **kwargs: Any) -> "MR":
|
||||
"""
|
||||
Generate a response for the given prompt and history.
|
||||
|
||||
Args
|
||||
----
|
||||
prompt: The prompt to generate a response for.
|
||||
history: Optional chat history.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
LitellmModelResponse: The generated model response.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
messages: list[dict[str, str]] = history or []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = self.completion(messages=messages, stream=False, **new_kwargs) # type: ignore
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.choices[0].message.content or "", # type: ignore
|
||||
})
|
||||
|
||||
parsed_response: BaseModel | None = None
|
||||
if "response_format" in new_kwargs:
|
||||
parsed_dict: dict[str, Any] = json.loads(
|
||||
response.choices[0].message.content or "{}" # type: ignore
|
||||
)
|
||||
parsed_response = parsed_dict # type: ignore
|
||||
if inspect.isclass(new_kwargs["response_format"]) and issubclass(
|
||||
new_kwargs["response_format"], BaseModel
|
||||
):
|
||||
# If response_format is a pydantic model, instantiate it
|
||||
model_initializer = cast(
|
||||
"type[BaseModel]", new_kwargs["response_format"]
|
||||
)
|
||||
parsed_response = model_initializer(**parsed_dict)
|
||||
|
||||
return LitellmModelResponse(
|
||||
output=LitellmModelOutput(
|
||||
content=response.choices[0].message.content or "" # type: ignore
|
||||
),
|
||||
parsed_response=parsed_response,
|
||||
history=messages,
|
||||
)
|
||||
|
||||
def chat_stream(
|
||||
self, prompt: str, history: list | None = None, **kwargs: Any
|
||||
) -> Generator[str, None]:
|
||||
"""
|
||||
Generate a response for the given prompt and history.
|
||||
|
||||
Args
|
||||
----
|
||||
prompt: The prompt to generate a response for.
|
||||
history: Optional chat history.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Generator[str, None]: The generated response as a stream of strings.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
messages: list[dict[str, str]] = history or []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = self.completion(messages=messages, stream=True, **new_kwargs) # type: ignore
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices and chunk.choices[0].delta.content: # type: ignore
|
||||
yield chunk.choices[0].delta.content # type: ignore
|
||||
280
graphrag/language_model/providers/litellm/embedding_model.py
Normal file
280
graphrag/language_model/providers/litellm/embedding_model.py
Normal file
@ -0,0 +1,280 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Embedding model implementation using Litellm."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import litellm
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from litellm import (
|
||||
EmbeddingResponse, # type: ignore
|
||||
aembedding,
|
||||
embedding,
|
||||
)
|
||||
|
||||
from graphrag.config.defaults import COGNITIVE_SERVICES_AUDIENCE
|
||||
from graphrag.config.enums import AuthType
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_cache import (
|
||||
with_cache,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_logging import (
|
||||
with_logging,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_rate_limiter import (
|
||||
with_rate_limiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.request_wrappers.with_retries import (
|
||||
with_retries,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.types import (
|
||||
AFixedModelEmbedding,
|
||||
FixedModelEmbedding,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
|
||||
def _create_base_embeddings(
|
||||
model_config: "LanguageModelConfig",
|
||||
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
|
||||
"""Wrap the base litellm embedding function with the model configuration.
|
||||
|
||||
Args
|
||||
----
|
||||
model_config: The configuration for the language model.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing the synchronous and asynchronous embedding functions.
|
||||
"""
|
||||
model_provider = model_config.model_provider
|
||||
model = model_config.deployment_name or model_config.model
|
||||
|
||||
base_args: dict[str, Any] = {
|
||||
"drop_params": True, # LiteLLM drop unsupported params for selected model.
|
||||
"model": f"{model_provider}/{model}",
|
||||
"timeout": model_config.request_timeout,
|
||||
"api_base": model_config.api_base,
|
||||
"api_version": model_config.api_version,
|
||||
"api_key": model_config.api_key,
|
||||
"organization": model_config.organization,
|
||||
"proxy": model_config.proxy,
|
||||
"audience": model_config.audience,
|
||||
}
|
||||
|
||||
if model_config.auth_type == AuthType.AzureManagedIdentity:
|
||||
if model_config.model_provider != "azure":
|
||||
msg = "Azure Managed Identity authentication is only supported for Azure models."
|
||||
raise ValueError(msg)
|
||||
|
||||
base_args["azure_scope"] = base_args.pop("audience")
|
||||
base_args["azure_ad_token_provider"] = get_bearer_token_provider(
|
||||
DefaultAzureCredential(),
|
||||
model_config.audience or COGNITIVE_SERVICES_AUDIENCE,
|
||||
)
|
||||
|
||||
def _base_embedding(**kwargs: Any) -> EmbeddingResponse:
|
||||
new_args = {**base_args, **kwargs}
|
||||
|
||||
if "name" in new_args:
|
||||
new_args.pop("name")
|
||||
|
||||
return embedding(**new_args)
|
||||
|
||||
async def _base_aembedding(**kwargs: Any) -> EmbeddingResponse:
|
||||
new_args = {**base_args, **kwargs}
|
||||
|
||||
if "name" in new_args:
|
||||
new_args.pop("name")
|
||||
|
||||
return await aembedding(**new_args)
|
||||
|
||||
return (_base_embedding, _base_aembedding)
|
||||
|
||||
|
||||
def _create_embeddings(
|
||||
model_config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None",
|
||||
cache_key_prefix: str,
|
||||
) -> tuple[FixedModelEmbedding, AFixedModelEmbedding]:
|
||||
"""Wrap the base litellm embedding function with the model configuration and additional features.
|
||||
|
||||
Wrap the base litellm embedding function with instance variables based on the model configuration.
|
||||
Then wrap additional features such as rate limiting, retries, and caching, if enabled.
|
||||
|
||||
Final function composition order:
|
||||
- Logging(Cache(Retries(RateLimiter(ModelEmbedding()))))
|
||||
|
||||
Args
|
||||
----
|
||||
model_config: The configuration for the language model.
|
||||
cache: Optional cache for storing responses.
|
||||
cache_key_prefix: Prefix for cache keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing the synchronous and asynchronous embedding functions.
|
||||
|
||||
"""
|
||||
embedding, aembedding = _create_base_embeddings(model_config)
|
||||
|
||||
# TODO: For v2.x release, rpm/tpm can be int or str (auto) for backwards compatibility with fnllm.
|
||||
# LiteLLM does not support "auto", so we have to check those values here.
|
||||
# For v3 release, force rpm/tpm to be int and remove the type checks below
|
||||
# and just check if rate_limit_strategy is enabled.
|
||||
if model_config.rate_limit_strategy is not None:
|
||||
rpm = (
|
||||
model_config.requests_per_minute
|
||||
if type(model_config.requests_per_minute) is int
|
||||
else None
|
||||
)
|
||||
tpm = (
|
||||
model_config.tokens_per_minute
|
||||
if type(model_config.tokens_per_minute) is int
|
||||
else None
|
||||
)
|
||||
if rpm is not None or tpm is not None:
|
||||
embedding, aembedding = with_rate_limiter(
|
||||
sync_fn=embedding,
|
||||
async_fn=aembedding,
|
||||
model_config=model_config,
|
||||
rpm=rpm,
|
||||
tpm=tpm,
|
||||
)
|
||||
|
||||
if model_config.retry_strategy != "none":
|
||||
embedding, aembedding = with_retries(
|
||||
sync_fn=embedding,
|
||||
async_fn=aembedding,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
embedding, aembedding = with_cache(
|
||||
sync_fn=embedding,
|
||||
async_fn=aembedding,
|
||||
model_config=model_config,
|
||||
cache=cache,
|
||||
request_type="embedding",
|
||||
cache_key_prefix=cache_key_prefix,
|
||||
)
|
||||
|
||||
embedding, aembedding = with_logging(
|
||||
sync_fn=embedding,
|
||||
async_fn=aembedding,
|
||||
)
|
||||
|
||||
return (embedding, aembedding)
|
||||
|
||||
|
||||
class LitellmEmbeddingModel:
|
||||
"""LiteLLM-based Embedding Model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
config: "LanguageModelConfig",
|
||||
cache: "PipelineCache | None" = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.name = name
|
||||
self.config = config
|
||||
self.cache = cache.child(self.name) if cache else None
|
||||
self.embedding, self.aembedding = _create_embeddings(
|
||||
config, self.cache, "embeddings"
|
||||
)
|
||||
|
||||
def _get_kwargs(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Get model arguments supported by litellm."""
|
||||
args_to_include = [
|
||||
"name",
|
||||
"dimensions",
|
||||
"encoding_format",
|
||||
"timeout",
|
||||
"user",
|
||||
]
|
||||
return {k: v for k, v in kwargs.items() if k in args_to_include}
|
||||
|
||||
async def aembed_batch(
|
||||
self, text_list: list[str], **kwargs: Any
|
||||
) -> list[list[float]]:
|
||||
"""
|
||||
Batch generate embeddings.
|
||||
|
||||
Args
|
||||
----
|
||||
text_list: A batch of text inputs to generate embeddings for.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Batch of embeddings.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
response = await self.aembedding(input=text_list, **new_kwargs)
|
||||
|
||||
return [emb.get("embedding", []) for emb in response.data]
|
||||
|
||||
async def aembed(self, text: str, **kwargs: Any) -> list[float]:
|
||||
"""
|
||||
Async embed.
|
||||
|
||||
Args:
|
||||
text: The text to generate an embedding for.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
An embedding.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
response = await self.aembedding(input=[text], **new_kwargs)
|
||||
|
||||
return (
|
||||
response.data[0].get("embedding", [])
|
||||
if response.data and response.data[0]
|
||||
else []
|
||||
)
|
||||
|
||||
def embed_batch(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
|
||||
"""
|
||||
Batch generate embeddings.
|
||||
|
||||
Args:
|
||||
text_list: A batch of text inputs to generate embeddings for.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
A Batch of embeddings.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
response = self.embedding(input=text_list, **new_kwargs)
|
||||
|
||||
return [emb.get("embedding", []) for emb in response.data]
|
||||
|
||||
def embed(self, text: str, **kwargs: Any) -> list[float]:
|
||||
"""
|
||||
Embed a single text input.
|
||||
|
||||
Args:
|
||||
text: The text to generate an embedding for.
|
||||
**kwargs: Additional keyword arguments (e.g., model parameters).
|
||||
|
||||
Returns
|
||||
-------
|
||||
An embedding.
|
||||
"""
|
||||
new_kwargs = self._get_kwargs(**kwargs)
|
||||
response = self.embedding(input=[text], **new_kwargs)
|
||||
|
||||
return (
|
||||
response.data[0].get("embedding", [])
|
||||
if response.data and response.data[0]
|
||||
else []
|
||||
)
|
||||
140
graphrag/language_model/providers/litellm/get_cache_key.py
Normal file
140
graphrag/language_model/providers/litellm/get_cache_key.py
Normal file
@ -0,0 +1,140 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""
|
||||
LiteLLM cache key generation.
|
||||
|
||||
Modeled after the fnllm cache key generation.
|
||||
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
|
||||
|
||||
_CACHE_VERSION = 3
|
||||
"""
|
||||
If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches.
|
||||
|
||||
fnllm was on cache version 2 and though we generate
|
||||
similar cache keys, the objects stored in cache by fnllm and litellm are different.
|
||||
Using litellm model providers will not be able to reuse caches generated by fnllm
|
||||
thus we start with version 3 for litellm.
|
||||
"""
|
||||
|
||||
|
||||
def get_cache_key(
|
||||
model_config: "LanguageModelConfig",
|
||||
prefix: str,
|
||||
messages: str | None = None,
|
||||
input: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate a cache key based on the model configuration and input arguments.
|
||||
|
||||
Modeled after the fnllm cache key generation.
|
||||
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
|
||||
|
||||
Args
|
||||
____
|
||||
model_config: The configuration of the language model.
|
||||
prefix: A prefix for the cache key.
|
||||
**kwargs: Additional model input parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
`{prefix}_{data_hash}_v{version}` if prefix is provided.
|
||||
"""
|
||||
cache_key: dict[str, Any] = {
|
||||
"parameters": _get_parameters(model_config, **kwargs),
|
||||
}
|
||||
|
||||
if messages is not None and input is not None:
|
||||
msg = "Only one of 'messages' or 'input' should be provided."
|
||||
raise ValueError(msg)
|
||||
|
||||
if messages is not None:
|
||||
cache_key["messages"] = messages
|
||||
elif input is not None:
|
||||
cache_key["input"] = input
|
||||
else:
|
||||
msg = "Either 'messages' or 'input' must be provided."
|
||||
raise ValueError(msg)
|
||||
|
||||
data_hash = _hash(json.dumps(cache_key, sort_keys=True))
|
||||
|
||||
name = kwargs.get("name")
|
||||
|
||||
if name:
|
||||
prefix += f"_{name}"
|
||||
|
||||
return f"{prefix}_{data_hash}_v{_CACHE_VERSION}"
|
||||
|
||||
|
||||
def _get_parameters(
|
||||
model_config: "LanguageModelConfig",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Pluck out the parameters that define a cache key.
|
||||
|
||||
Use the same parameters as fnllm except request timeout.
|
||||
- embeddings: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/embeddings/parameters.py#L12
|
||||
- chat: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/chat/parameters.py#L25
|
||||
|
||||
Args
|
||||
____
|
||||
model_config: The configuration of the language model.
|
||||
**kwargs: Additional model input parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]: A dictionary of parameters that define the cache key.
|
||||
"""
|
||||
parameters = {
|
||||
"model": model_config.deployment_name or model_config.model,
|
||||
"frequency_penalty": model_config.frequency_penalty,
|
||||
"max_tokens": model_config.max_tokens,
|
||||
"max_completion_tokens": model_config.max_completion_tokens,
|
||||
"n": model_config.n,
|
||||
"presence_penalty": model_config.presence_penalty,
|
||||
"temperature": model_config.temperature,
|
||||
"top_p": model_config.top_p,
|
||||
"reasoning_effort": model_config.reasoning_effort,
|
||||
}
|
||||
keys_to_cache = [
|
||||
"function_call",
|
||||
"functions",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"parallel_tool_calls",
|
||||
"seed",
|
||||
"service_tier",
|
||||
"stop",
|
||||
"tool_choice",
|
||||
"tools",
|
||||
"top_logprobs",
|
||||
"user",
|
||||
"dimensions",
|
||||
"encoding_format",
|
||||
]
|
||||
parameters.update({key: kwargs.get(key) for key in keys_to_cache if key in kwargs})
|
||||
|
||||
response_format = kwargs.get("response_format")
|
||||
if inspect.isclass(response_format) and issubclass(response_format, BaseModel):
|
||||
parameters["response_format"] = str(response_format)
|
||||
elif response_format is not None:
|
||||
parameters["response_format"] = response_format
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def _hash(input: str) -> str:
|
||||
"""Generate a hash for the input string."""
|
||||
return hashlib.sha256(input.encode()).hexdigest()
|
||||
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""LiteLLM completion/embedding function wrappers."""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user