mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdb7e3835b | ||
|
|
ac8a7f5eef | ||
|
|
6c86b0a7bb | ||
|
|
2bd3922d8d | ||
|
|
7f996cf584 | ||
|
|
9bc899fe95 | ||
|
|
2b70e4a4f3 | ||
|
|
82cd3b7df2 | ||
|
|
075cadd59a | ||
|
|
6d7a50b7f0 | ||
|
|
2bf7e7c018 | ||
|
|
6c66b7c30f | ||
|
|
a398cc38bb | ||
|
|
ac95c917d3 | ||
|
|
1cb20b66f5 | ||
|
|
2030f94eb4 | ||
|
|
69ad36e735 | ||
|
|
30bdb35cc8 | ||
|
|
77fb7d9d7d | ||
|
|
469ee8568f | ||
|
|
7c28c70d5c | ||
|
|
5713205210 | ||
|
|
1da1380615 | ||
|
|
dce02563eb | ||
|
|
13bf315a35 | ||
|
|
e84df28e64 | ||
|
|
27c6de846f | ||
|
|
1df89727c3 | ||
|
|
17e431cf42 |
17
.github/workflows/gh-pages.yml
vendored
17
.github/workflows/gh-pages.yml
vendored
@ -6,8 +6,7 @@ permissions:
|
||||
contents: write
|
||||
|
||||
env:
|
||||
POETRY_VERSION: '1.8.3'
|
||||
PYTHON_VERSION: '3.11'
|
||||
PYTHON_VERSION: "3.11"
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@ -16,8 +15,6 @@ jobs:
|
||||
GH_PAGES: 1
|
||||
DEBUG: 1
|
||||
GRAPHRAG_API_KEY: ${{ secrets.GRAPHRAG_API_KEY }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@ -29,18 +26,16 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install Poetry ${{ env.POETRY_VERSION }}
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: poetry intsall
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
run: uv sync
|
||||
|
||||
- name: mkdocs build
|
||||
shell: bash
|
||||
run: poetry run poe build_docs
|
||||
run: uv run poe build_docs
|
||||
|
||||
- name: List Docsite Contents
|
||||
run: find site
|
||||
|
||||
24
.github/workflows/python-ci.yml
vendored
24
.github/workflows/python-ci.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -51,7 +48,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -64,30 +61,27 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Check
|
||||
run: |
|
||||
poetry run poe check
|
||||
uv run poe check
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Unit Test
|
||||
run: |
|
||||
poetry run poe test_unit
|
||||
uv run poe test_unit
|
||||
|
||||
- name: Verb Test
|
||||
run: |
|
||||
poetry run poe test_verbs
|
||||
uv run poe test_verbs
|
||||
|
||||
28
.github/workflows/python-integration-tests.yml
vendored
28
.github/workflows/python-integration-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -51,7 +48,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -64,25 +61,24 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Install Azurite
|
||||
id: azuright
|
||||
uses: potatoqualitee/azuright@v1.1
|
||||
- name: Install and start Azurite
|
||||
shell: bash
|
||||
run: |
|
||||
npm install -g azurite
|
||||
azurite --silent --skipApiVersionCheck --location /tmp/azurite --debug /tmp/azurite-debug.log &
|
||||
|
||||
# For more information on installation/setup of Azure Cosmos DB Emulator
|
||||
# https://learn.microsoft.com/en-us/azure/cosmos-db/how-to-develop-emulator?tabs=docker-linux%2Cpython&pivots=api-nosql
|
||||
@ -97,4 +93,4 @@ jobs:
|
||||
|
||||
- name: Integration Test
|
||||
run: |
|
||||
poetry run poe test_integration
|
||||
uv run poe test_integration
|
||||
|
||||
20
.github/workflows/python-notebook-tests.yml
vendored
20
.github/workflows/python-notebook-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -41,8 +38,6 @@ jobs:
|
||||
env:
|
||||
DEBUG: 1
|
||||
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_NOTEBOOK_KEY }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
@ -54,7 +49,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -66,18 +61,15 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Notebook Test
|
||||
run: |
|
||||
poetry run poe test_notebook
|
||||
uv run poe test_notebook
|
||||
|
||||
13
.github/workflows/python-publish.yml
vendored
13
.github/workflows/python-publish.yml
vendored
@ -6,7 +6,6 @@ on:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
POETRY_VERSION: "1.8.3"
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
@ -31,21 +30,19 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: ${{ env.POETRY_VERSION }}
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: poetry install
|
||||
run: uv sync
|
||||
|
||||
- name: Export Publication Version
|
||||
run: echo "version=`poetry version --short`" >> $GITHUB_OUTPUT
|
||||
run: echo "version=$(uv version --short)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build Distributable
|
||||
shell: bash
|
||||
run: poetry build
|
||||
run: uv build
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
42
.github/workflows/python-smoke-tests.yml
vendored
42
.github/workflows/python-smoke-tests.yml
vendored
@ -26,9 +26,6 @@ concurrency:
|
||||
# Only run the for the latest commit
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
POETRY_VERSION: 1.8.3
|
||||
|
||||
jobs:
|
||||
python-ci:
|
||||
# skip draft PRs
|
||||
@ -40,20 +37,8 @@ jobs:
|
||||
fail-fast: false # Continue running all jobs even if one fails
|
||||
env:
|
||||
DEBUG: 1
|
||||
GRAPHRAG_LLM_TYPE: "azure_openai_chat"
|
||||
GRAPHRAG_EMBEDDING_TYPE: "azure_openai_embedding"
|
||||
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
GRAPHRAG_API_BASE: ${{ secrets.GRAPHRAG_API_BASE }}
|
||||
GRAPHRAG_API_VERSION: ${{ secrets.GRAPHRAG_API_VERSION }}
|
||||
GRAPHRAG_LLM_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_LLM_DEPLOYMENT_NAME }}
|
||||
GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME: ${{ secrets.GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME }}
|
||||
GRAPHRAG_LLM_MODEL: ${{ secrets.GRAPHRAG_LLM_MODEL }}
|
||||
GRAPHRAG_EMBEDDING_MODEL: ${{ secrets.GRAPHRAG_EMBEDDING_MODEL }}
|
||||
# We have Windows + Linux runners in 3.10, so we need to divide the rate limits by 2
|
||||
GRAPHRAG_LLM_TPM: 200_000 # 400_000 / 2
|
||||
GRAPHRAG_LLM_RPM: 1_000 # 2_000 / 2
|
||||
GRAPHRAG_EMBEDDING_TPM: 225_000 # 450_000 / 2
|
||||
GRAPHRAG_EMBEDDING_RPM: 1_000 # 2_000 / 2
|
||||
# Azure AI Search config
|
||||
AZURE_AI_SEARCH_URL_ENDPOINT: ${{ secrets.AZURE_AI_SEARCH_URL_ENDPOINT }}
|
||||
AZURE_AI_SEARCH_API_KEY: ${{ secrets.AZURE_AI_SEARCH_API_KEY }}
|
||||
@ -68,7 +53,7 @@ jobs:
|
||||
filters: |
|
||||
python:
|
||||
- 'graphrag/**/*'
|
||||
- 'poetry.lock'
|
||||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- '**/*.py'
|
||||
- '**/*.toml'
|
||||
@ -81,33 +66,32 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3.0.0
|
||||
with:
|
||||
poetry-version: $POETRY_VERSION
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
poetry self add setuptools wheel
|
||||
poetry run python -m pip install gensim
|
||||
poetry install
|
||||
uv sync
|
||||
uv pip install gensim
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
poetry build
|
||||
uv build
|
||||
|
||||
- name: Install Azurite
|
||||
id: azuright
|
||||
uses: potatoqualitee/azuright@v1.1
|
||||
- name: Install and start Azurite
|
||||
shell: bash
|
||||
run: |
|
||||
npm install -g azurite
|
||||
azurite --silent --skipApiVersionCheck --location /tmp/azurite --debug /tmp/azurite-debug.log &
|
||||
|
||||
- name: Smoke Test
|
||||
if: steps.changes.outputs.python == 'true'
|
||||
run: |
|
||||
poetry run poe test_smoke
|
||||
uv run poe test_smoke
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
|
||||
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ runner.os }}
|
||||
path: tests/fixtures/*
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,6 +1,8 @@
|
||||
# Python Artifacts
|
||||
python/*/lib/
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
|
||||
# Test Output
|
||||
.coverage
|
||||
|
||||
26
.semversioner/2.4.0.json
Normal file
26
.semversioner/2.4.0.json
Normal file
@ -0,0 +1,26 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Allow injection of custom pipelines.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Refactored StorageFactory to use a registration-based approach",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Fix default values for tpm and rpm limiters on embeddings",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Update typer.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "cleaned up logging to follow python standards.",
|
||||
"type": "patch"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-07-15T00:04:15+00:00",
|
||||
"version": "2.4.0"
|
||||
}
|
||||
14
.semversioner/2.5.0.json
Normal file
14
.semversioner/2.5.0.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Add additional context variable to build index signature for custom parameter bag",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "swap package management from Poetry -> UV",
|
||||
"type": "minor"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-08-14T00:59:46+00:00",
|
||||
"version": "2.5.0"
|
||||
}
|
||||
54
.semversioner/2.6.0.json
Normal file
54
.semversioner/2.6.0.json
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Add LiteLLM chat and embedding model providers.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add LoggerFactory and clean up related API.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add config for NLP async mode.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add optional input documents to indexing API.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "add customization to vector store",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Add gpt-5 support by updating fnllm dependency.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Fix all human_readable_id fields to be 0-based.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Fix multi-index search.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Improve upon recent logging refactor",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Make cache, storage, and vector_store factories consistent with similar registration support",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Remove hard-coded community rate limiter.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "generate_text_embeddings only loads tables if embedding field is specified.",
|
||||
"type": "patch"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-09-22T21:44:51+00:00",
|
||||
"version": "2.6.0"
|
||||
}
|
||||
18
.semversioner/2.7.0.json
Normal file
18
.semversioner/2.7.0.json
Normal file
@ -0,0 +1,18 @@
|
||||
{
|
||||
"changes": [
|
||||
{
|
||||
"description": "Set LiteLLM as default in init_content.",
|
||||
"type": "minor"
|
||||
},
|
||||
{
|
||||
"description": "Fix Azure auth scope issue with LiteLLM.",
|
||||
"type": "patch"
|
||||
},
|
||||
{
|
||||
"description": "Housekeeping toward 2.7.",
|
||||
"type": "patch"
|
||||
}
|
||||
],
|
||||
"created_at": "2025-10-08T22:39:42+00:00",
|
||||
"version": "2.7.0"
|
||||
}
|
||||
57
.vscode/launch.json
vendored
57
.vscode/launch.json
vendored
@ -6,21 +6,24 @@
|
||||
"name": "Indexer",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "index",
|
||||
"--root", "<path_to_ragtest_root_demo>"
|
||||
"index",
|
||||
"--root",
|
||||
"<path_to_index_folder>"
|
||||
],
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Query",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "graphrag",
|
||||
"args": [
|
||||
"poe", "query",
|
||||
"--root", "<path_to_ragtest_root_demo>",
|
||||
"--method", "global",
|
||||
"query",
|
||||
"--root",
|
||||
"<path_to_index_folder>",
|
||||
"--method", "basic",
|
||||
"--query", "What are the top themes in this story",
|
||||
]
|
||||
},
|
||||
@ -28,12 +31,48 @@
|
||||
"name": "Prompt Tuning",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "poetry",
|
||||
"module": "uv",
|
||||
"args": [
|
||||
"poe", "prompt-tune",
|
||||
"--config",
|
||||
"<path_to_ragtest_root_demo>/settings.yaml",
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Integration Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/integration/vector_stores",
|
||||
"-k", "test_azure_ai_search"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Verbs Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/verbs",
|
||||
"-k", "test_generate_text_embeddings"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug Smoke Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/smoke",
|
||||
"-k", "test_fixtures"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
]
|
||||
}
|
||||
37
.vscode/settings.json
vendored
37
.vscode/settings.json
vendored
@ -1,43 +1,8 @@
|
||||
{
|
||||
"search.exclude": {
|
||||
"**/.yarn": true,
|
||||
"**/.pnp.*": true
|
||||
},
|
||||
"editor.formatOnSave": false,
|
||||
"eslint.nodePath": ".yarn/sdks",
|
||||
"typescript.tsdk": ".yarn/sdks/typescript/lib",
|
||||
"typescript.enablePromptUseWorkspaceTsdk": true,
|
||||
"javascript.preferences.importModuleSpecifier": "relative",
|
||||
"javascript.preferences.importModuleSpecifierEnding": "js",
|
||||
"typescript.preferences.importModuleSpecifier": "relative",
|
||||
"typescript.preferences.importModuleSpecifierEnding": "js",
|
||||
"explorer.fileNesting.enabled": true,
|
||||
"explorer.fileNesting.patterns": {
|
||||
"*.ts": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md",
|
||||
"*.js": "${capture}.js.map, ${capture}.min.js, ${capture}.d.ts",
|
||||
"*.jsx": "${capture}.js",
|
||||
"*.tsx": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md, ${capture}.css",
|
||||
"tsconfig.json": "tsconfig.*.json",
|
||||
"package.json": "package-lock.json, turbo.json, tsconfig.json, rome.json, biome.json, .npmignore, dictionary.txt, cspell.config.yaml",
|
||||
"README.md": "*.md, LICENSE, CODEOWNERS",
|
||||
".eslintrc": ".eslintignore",
|
||||
".prettierrc": ".prettierignore",
|
||||
".gitattributes": ".gitignore",
|
||||
".yarnrc.yml": "yarn.lock, .pnp.*",
|
||||
"jest.config.js": "jest.setup.mjs",
|
||||
"pyproject.toml": "poetry.lock, poetry.toml, mkdocs.yaml",
|
||||
"cspell.config.yaml": "dictionary.txt"
|
||||
},
|
||||
"azureFunctions.postDeployTask": "npm install (functions)",
|
||||
"azureFunctions.projectLanguage": "TypeScript",
|
||||
"azureFunctions.projectRuntime": "~4",
|
||||
"debug.internalConsoleOptions": "neverOpen",
|
||||
"azureFunctions.preDeployTask": "npm prune (functions)",
|
||||
"appService.zipIgnorePattern": [
|
||||
"node_modules{,/**}",
|
||||
".vscode{,/**}"
|
||||
],
|
||||
"python.defaultInterpreterPath": "python/services/.venv/bin/python",
|
||||
"python.defaultInterpreterPath": "${workspaceRoot}/.venv/bin/python",
|
||||
"python.languageServer": "Pylance",
|
||||
"cSpell.customDictionaries": {
|
||||
"project-words": {
|
||||
|
||||
34
CHANGELOG.md
34
CHANGELOG.md
@ -1,6 +1,40 @@
|
||||
# Changelog
|
||||
Note: version releases in the 0.x.y range may introduce breaking changes.
|
||||
|
||||
## 2.7.0
|
||||
|
||||
- minor: Set LiteLLM as default in init_content.
|
||||
- patch: Fix Azure auth scope issue with LiteLLM.
|
||||
- patch: Housekeeping toward 2.7.
|
||||
|
||||
## 2.6.0
|
||||
|
||||
- minor: Add LiteLLM chat and embedding model providers.
|
||||
- minor: Add LoggerFactory and clean up related API.
|
||||
- minor: Add config for NLP async mode.
|
||||
- minor: Add optional input documents to indexing API.
|
||||
- minor: add customization to vector store
|
||||
- patch: Add gpt-5 support by updating fnllm dependency.
|
||||
- patch: Fix all human_readable_id fields to be 0-based.
|
||||
- patch: Fix multi-index search.
|
||||
- patch: Improve upon recent logging refactor
|
||||
- patch: Make cache, storage, and vector_store factories consistent with similar registration support
|
||||
- patch: Remove hard-coded community rate limiter.
|
||||
- patch: generate_text_embeddings only loads tables if embedding field is specified.
|
||||
|
||||
## 2.5.0
|
||||
|
||||
- minor: Add additional context variable to build index signature for custom parameter bag
|
||||
- minor: swap package management from Poetry -> UV
|
||||
|
||||
## 2.4.0
|
||||
|
||||
- minor: Allow injection of custom pipelines.
|
||||
- minor: Refactored StorageFactory to use a registration-based approach
|
||||
- patch: Fix default values for tpm and rpm limiters on embeddings
|
||||
- patch: Update typer.
|
||||
- patch: cleaned up logging to follow python standards.
|
||||
|
||||
## 2.3.0
|
||||
|
||||
- minor: Remove Dynamic Max Retries support. Refactor typer typing in cli interface
|
||||
|
||||
@ -22,7 +22,7 @@ or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any addi
|
||||
2. Create a new branch for your contribution: `git checkout -b my-contribution`.
|
||||
3. Make your changes and ensure that the code passes all tests.
|
||||
4. Commit your changes: `git commit -m "Add my contribution"`.
|
||||
5. Create and commit a semver impact document by running `poetry run semversioner add-change -t <major|minor|patch> -d <description>`.
|
||||
5. Create and commit a semver impact document by running `uv run semversioner add-change -t <major|minor|patch> -d <description>`.
|
||||
6. Push your changes to your forked repository: `git push origin my-contribution`.
|
||||
7. Open a pull request to the main repository.
|
||||
|
||||
|
||||
@ -5,29 +5,29 @@
|
||||
| Name | Installation | Purpose |
|
||||
| ------------------- | ------------------------------------------------------------ | ----------------------------------------------------------------------------------- |
|
||||
| Python 3.10 or 3.11 | [Download](https://www.python.org/downloads/) | The library is Python-based. |
|
||||
| Poetry | [Instructions](https://python-poetry.org/docs/#installation) | Poetry is used for package management and virtualenv management in Python codebases |
|
||||
| uv | [Instructions](https://docs.astral.sh/uv/) | uv is used for package management and virtualenv management in Python codebases |
|
||||
|
||||
# Getting Started
|
||||
|
||||
## Install Dependencies
|
||||
```shell
|
||||
# install python dependencies
|
||||
poetry install
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Execute the indexing engine
|
||||
```shell
|
||||
poetry run poe index <...args>
|
||||
uv run poe index <...args>
|
||||
```
|
||||
|
||||
## Execute prompt tuning
|
||||
```shell
|
||||
poetry run poe prompt_tune <...args>
|
||||
uv run poe prompt_tune <...args>
|
||||
```
|
||||
|
||||
## Execute Queries
|
||||
```shell
|
||||
poetry run poe query <...args>
|
||||
uv run poe query <...args>
|
||||
```
|
||||
|
||||
## Repository Structure
|
||||
@ -63,7 +63,7 @@ Where appropriate, the factories expose a registration method for users to provi
|
||||
|
||||
We use [semversioner](https://github.com/raulgomis/semversioner) to automate and enforce semantic versioning in the release process. Our CI/CD pipeline checks that all PR's include a json file generated by semversioner. When submitting a PR, please run:
|
||||
```shell
|
||||
poetry run semversioner add-change -t patch -d "<a small sentence describing changes made>."
|
||||
uv run semversioner add-change -t patch -d "<a small sentence describing changes made>."
|
||||
```
|
||||
|
||||
# Azurite
|
||||
@ -78,29 +78,29 @@ or by simply running `azurite` in the terminal if already installed globally. Se
|
||||
|
||||
# Lifecycle Scripts
|
||||
|
||||
Our Python package utilizes Poetry to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage custom build scripts.
|
||||
Our Python package utilizes uv to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage custom build scripts.
|
||||
|
||||
Available scripts are:
|
||||
- `poetry run poe index` - Run the Indexing CLI
|
||||
- `poetry run poe query` - Run the Query CLI
|
||||
- `poetry build` - This invokes `poetry build`, which will build a wheel file and other distributable artifacts.
|
||||
- `poetry run poe test` - This will execute all tests.
|
||||
- `poetry run poe test_unit` - This will execute unit tests.
|
||||
- `poetry run poe test_integration` - This will execute integration tests.
|
||||
- `poetry run poe test_smoke` - This will execute smoke tests.
|
||||
- `poetry run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- `uv run poe index` - Run the Indexing CLI
|
||||
- `uv run poe query` - Run the Query CLI
|
||||
- `uv build` - This invokes `uv build`, which will build a wheel file and other distributable artifacts.
|
||||
- `uv run poe test` - This will execute all tests.
|
||||
- `uv run poe test_unit` - This will execute unit tests.
|
||||
- `uv run poe test_integration` - This will execute integration tests.
|
||||
- `uv run poe test_smoke` - This will execute smoke tests.
|
||||
- `uv run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- formatting
|
||||
- documentation formatting
|
||||
- linting
|
||||
- security patterns
|
||||
- type-checking
|
||||
- `poetry run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `poetry run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `poetry run poe format` - Explicitly run the formatter across the package.
|
||||
- `uv run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `uv run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `uv run poe format` - Explicitly run the formatter across the package.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running poetry install
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running uv sync
|
||||
|
||||
Make sure llvm-9 and llvm-9-dev are installed:
|
||||
|
||||
@ -110,13 +110,8 @@ and then in your bashrc, add
|
||||
|
||||
`export LLVM_CONFIG=/usr/bin/llvm-config-9`
|
||||
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running poetry install
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running uv sync
|
||||
|
||||
Make sure you have python3.10-dev installed or more generally `python<version>-dev`
|
||||
|
||||
`sudo apt-get install python3.10-dev`
|
||||
|
||||
### LLM call constantly exceeds TPM, RPM or time limits
|
||||
|
||||
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
|
||||
to reduce concurrency. Please refer to the [Configuration Documents](https://microsoft.github.io/graphrag/config/overview/)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# GraphRAG
|
||||
|
||||
👉 [Use the GraphRAG Accelerator solution](https://github.com/Azure-Samples/graphrag-accelerator) <br/>
|
||||
👉 [Microsoft Research Blog Post](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/)<br/>
|
||||
👉 [Read the docs](https://microsoft.github.io/graphrag)<br/>
|
||||
👉 [GraphRAG Arxiv](https://arxiv.org/pdf/2404.16130)
|
||||
@ -28,7 +27,7 @@ To learn more about GraphRAG and how it can be used to enhance your LLM's abilit
|
||||
|
||||
## Quickstart
|
||||
|
||||
To get started with the GraphRAG system we recommend trying the [Solution Accelerator](https://github.com/Azure-Samples/graphrag-accelerator) package. This provides a user-friendly end-to-end experience with Azure resources.
|
||||
To get started with the GraphRAG system we recommend trying the [command line quickstart](https://microsoft.github.io/graphrag/get_started/).
|
||||
|
||||
## Repository Guidance
|
||||
|
||||
|
||||
@ -12,6 +12,12 @@ There are five surface areas that may be impacted on any given release. They are
|
||||
|
||||
> TL;DR: Always run `graphrag init --path [path] --force` between minor version bumps to ensure you have the latest config format. Run the provided migration notebook between major version bumps if you want to avoid re-indexing prior datasets. Note that this will overwrite your configuration and prompts, so backup if necessary.
|
||||
|
||||
# v2
|
||||
|
||||
Run the [migration notebook](./docs/examples_notebooks/index_migration_to_v2.ipynb) to convert older tables to the v2 format.
|
||||
|
||||
The v2 release renamed all of our index tables to simply name the items each table contains. The previous naming was a leftover requirement of our use of DataShaper, which is no longer necessary.
|
||||
|
||||
# v1
|
||||
|
||||
Run the [migration notebook](./docs/examples_notebooks/index_migration_to_v1.ipynb) to convert older tables to the v1 format.
|
||||
@ -27,7 +33,7 @@ All of the breaking changes listed below are accounted for in the four steps abo
|
||||
- Alignment of fields from `create_final_entities` (such as name -> title) with `create_final_nodes`, and removal of redundant content across these tables
|
||||
- Rename of `document.raw_content` to `document.text`
|
||||
- Rename of `entity.name` to `entity.title`
|
||||
- Rename `rank` to `combined_degree` in `create_final_relationships` and removal of `source_degree` and `target_degree`fields
|
||||
- Rename `rank` to `combined_degree` in `create_final_relationships` and removal of `source_degree` and `target_degree` fields
|
||||
- Fixed community tables to use a proper UUID for the `id` field, and retain `community` and `human_readable_id` for the short IDs
|
||||
- Removal of all embeddings columns from parquet files in favor of direct vector store writes
|
||||
|
||||
|
||||
@ -81,6 +81,7 @@ typer
|
||||
spacy
|
||||
kwargs
|
||||
ollama
|
||||
litellm
|
||||
|
||||
# Library Methods
|
||||
iterrows
|
||||
@ -102,6 +103,9 @@ itertuples
|
||||
isin
|
||||
nocache
|
||||
nbconvert
|
||||
levelno
|
||||
acompletion
|
||||
aembedding
|
||||
|
||||
# HTML
|
||||
nbsp
|
||||
@ -186,6 +190,7 @@ Verdantis's
|
||||
# English
|
||||
skippable
|
||||
upvote
|
||||
unconfigured
|
||||
|
||||
# Misc
|
||||
Arxiv
|
||||
|
||||
@ -4,7 +4,7 @@ As of version 1.3, GraphRAG no longer supports a full complement of pre-built en
|
||||
|
||||
The only standard environment variable we expect, and include in the default settings.yml, is `GRAPHRAG_API_KEY`. If you are already using a number of the previous GRAPHRAG_* environment variables, you can insert them with template syntax into settings.yml and they will be adopted.
|
||||
|
||||
> **The environment variables below are documented as an aid for migration, but they WILL NOT be read unless you use template syntax in your settings.yml.**
|
||||
> **The environment variables below are documented as an aid for migration, but they WILL NOT be read unless you use template syntax in your settings.yml. We also WILL NOT be updating this page as the main config object changes.**
|
||||
|
||||
---
|
||||
|
||||
@ -178,11 +178,11 @@ This section controls the cache mechanism used by the pipeline. This is used to
|
||||
|
||||
### Reporting
|
||||
|
||||
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
|
||||
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.
|
||||
|
||||
| Parameter | Description | Type | Required or Optional | Default |
|
||||
| --------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | -------------------- | ------- |
|
||||
| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file`, `console`, or `blob` | `str` | optional | `file` |
|
||||
| `GRAPHRAG_REPORTING_TYPE` | The type of reporter to use. Options are `file` or `blob` | `str` | optional | `file` |
|
||||
| `GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL` | The Azure Storage blob endpoint to use when in `blob` mode and using managed identity. Will have the format `https://<storage_account_name>.blob.core.windows.net` | `str` | optional | None |
|
||||
| `GRAPHRAG_REPORTING_CONNECTION_STRING` | The Azure Storage connection string to use when in `blob` mode. | `str` | optional | None |
|
||||
| `GRAPHRAG_REPORTING_CONTAINER_NAME` | The Azure Storage container name to use when in `blob` mode. | `str` | optional | None |
|
||||
|
||||
@ -8,9 +8,38 @@ GraphRAG was built and tested using OpenAI models, so this is the default model
|
||||
|
||||
GraphRAG also utilizes a language model wrapper library used by several projects within our team, called fnllm. fnllm provides two important functions for GraphRAG: rate limiting configuration to help us maximize throughput for large indexing jobs, and robust caching of API calls to minimize consumption on repeated indexes for testing, experimentation, or incremental ingest. fnllm uses the OpenAI Python SDK under the covers, so OpenAI-compliant endpoints are a base requirement out-of-the-box.
|
||||
|
||||
Starting with version 2.6.0, GraphRAG supports using [LiteLLM](https://docs.litellm.ai/) instead of fnllm for calling language models. LiteLLM provides support for 100+ models though it is important to note that when choosing a model it must support returning [structured outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/) adhering to a [JSON schema](https://docs.litellm.ai/docs/completion/json_mode).
|
||||
|
||||
Example using LiteLLm as the language model tool for GraphRAG:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
default_chat_model:
|
||||
type: chat
|
||||
auth_type: api_key
|
||||
api_key: ${GEMINI_API_KEY}
|
||||
model_provider: gemini
|
||||
model: gemini-2.5-flash-lite
|
||||
default_embedding_model:
|
||||
type: embedding
|
||||
auth_type: api_key
|
||||
api_key: ${GEMINI_API_KEY}
|
||||
model_provider: gemini
|
||||
model: gemini-embedding-001
|
||||
```
|
||||
|
||||
To use LiteLLM one must
|
||||
|
||||
- Set `type` to either `chat` or `embedding`.
|
||||
- Provide a `model_provider`, e.g., `openai`, `azure`, `gemini`, etc.
|
||||
- Set the `model` to a one supported by the `model_provider`'s API.
|
||||
- Provide a `deployment_name` if using `azure` as the `model_provider`.
|
||||
|
||||
See [Detailed Configuration](yaml.md) for more details on configuration. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`).
|
||||
|
||||
## Model Selection Considerations
|
||||
|
||||
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo.
|
||||
GraphRAG has been most thoroughly tested with the gpt-4 series of models from OpenAI, including gpt-4 gpt-4-turbo, gpt-4o, and gpt-4o-mini. Our [arXiv paper](https://arxiv.org/abs/2404.16130), for example, performed quality evaluation using gpt-4-turbo. As stated above, non-OpenAI models are now supported with GraphRAG 2.6.0 and onwards through the use of LiteLLM but the suite of gpt-4 series of models from OpenAI remain the most tested and supported suite of models for GraphRAG.
|
||||
|
||||
Versions of GraphRAG before 2.2.0 made extensive use of `max_tokens` and `logit_bias` to control generated response length or content. The introduction of the o-series of models added new, non-compatible parameters because these models include a reasoning component that has different consumption patterns and response generation attributes than non-reasoning models. GraphRAG 2.2.0 now supports these models, but there are important differences that need to be understood before you switch.
|
||||
|
||||
@ -58,11 +87,11 @@ Another option would be to avoid using a language model at all for the graph ext
|
||||
|
||||
## Using Non-OpenAI Models
|
||||
|
||||
As noted above, our primary experience and focus has been on OpenAI models, so this is what is supported out-of-the-box. Many users have requested support for additional model types, but it's out of the scope of our research to handle the many models available today. There are two approaches you can use to connect to a non-OpenAI model:
|
||||
As shown above, non-OpenAI models may be used via LiteLLM starting with GraphRAG version 2.6.0 but cases may still exist in which some users wish to use models not supported by LiteLLM. There are two approaches one can use to connect to unsupported models:
|
||||
|
||||
### Proxy APIs
|
||||
|
||||
Many users have used platforms such as [ollama](https://ollama.com/) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
|
||||
Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLLM Proxy Server](https://docs.litellm.ai/docs/simple_proxy) to proxy the underlying model HTTP calls to a different model provider. This seems to work reasonably well, but we frequently see issues with malformed responses (especially JSON), so if you do this please understand that your model needs to reliably return the specific response formats that GraphRAG expects. If you're having trouble with a model, you may need to try prompting to coax the format, or intercepting the response within your proxy to try and handle malformed responses.
|
||||
|
||||
### Model Protocol
|
||||
|
||||
|
||||
@ -40,8 +40,9 @@ models:
|
||||
#### Fields
|
||||
|
||||
- `api_key` **str** - The OpenAI API key to use.
|
||||
- `auth_type` **api_key|managed_identity** - Indicate how you want to authenticate requests.
|
||||
- `type` **openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
|
||||
- `auth_type` **api_key|azure_managed_identity** - Indicate how you want to authenticate requests.
|
||||
- `type` **chat**|**embedding**|**openai_chat|azure_openai_chat|openai_embedding|azure_openai_embedding|mock_chat|mock_embeddings** - The type of LLM to use.
|
||||
- `model_provider` **str|None** - The model provider to use, e.g., openai, azure, anthropic, etc. Required when `type == chat|embedding`. When `type == chat|embedding`, [LiteLLM](https://docs.litellm.ai/) is used under the hood which has support for calling 100+ models. [View LiteLLm basic usage](https://docs.litellm.ai/docs/#basic-usage) for details on how models are called (The `model_provider` is the portion prior to `/` while the `model` is the portion following the `/`). [View Language Model Selection](models.md) for more details and examples on using LiteLLM.
|
||||
- `model` **str** - The model name.
|
||||
- `encoding_model` **str** - The text encoding model to use. Default is to use the encoding model aligned with the language model (i.e., it is retrieved from tiktoken if unset).
|
||||
- `api_base` **str** - The API base url to use.
|
||||
@ -73,16 +74,18 @@ models:
|
||||
|
||||
### input
|
||||
|
||||
Our pipeline can ingest .csv, .txt, or .json data from an input folder. See the [inputs page](../index/inputs.md) for more details and examples.
|
||||
Our pipeline can ingest .csv, .txt, or .json data from an input location. See the [inputs page](../index/inputs.md) for more details and examples.
|
||||
|
||||
#### Fields
|
||||
|
||||
- `type` **file|blob** - The input type to use. Default=`file`
|
||||
- `storage` **StorageConfig**
|
||||
- `type` **file|blob|cosmosdb** - The storage type to use. Default=`file`
|
||||
- `base_dir` **str** - The base directory to write output artifacts to, relative to the root.
|
||||
- `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string.
|
||||
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.
|
||||
- `storage_account_blob_url` **str** - (blob only) The storage account blob URL to use.
|
||||
- `cosmosdb_account_blob_url` **str** - (cosmosdb only) The CosmosDB account blob URL to use.
|
||||
- `file_type` **text|csv|json** - The type of input data to load. Default is `text`
|
||||
- `base_dir` **str** - The base directory to read input from, relative to the root.
|
||||
- `connection_string` **str** - (blob only) The Azure Storage connection string.
|
||||
- `storage_account_blob_url` **str** - The storage account blob URL to use.
|
||||
- `container_name` **str** - (blob only) The Azure Storage container name.
|
||||
- `encoding` **str** - The encoding of the input file. Default is `utf-8`
|
||||
- `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$`, `.*\.txt$`, or `.*\.json$` depending on the specified `file_type`, but you can customize it if needed.
|
||||
- `file_filter` **dict** - Key/value pairs to filter. Default is None.
|
||||
@ -147,11 +150,11 @@ This section controls the cache mechanism used by the pipeline. This is used to
|
||||
|
||||
### reporting
|
||||
|
||||
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to the console or to an Azure Blob Storage container.
|
||||
This section controls the reporting mechanism used by the pipeline, for common events and error messages. The default is to write reports to a file in the output directory. However, you can also choose to write reports to an Azure Blob Storage container.
|
||||
|
||||
#### Fields
|
||||
|
||||
- `type` **file|console|blob** - The reporting type to use. Default=`file`
|
||||
- `type` **file|blob** - The reporting type to use. Default=`file`
|
||||
- `base_dir` **str** - The base directory to write reports to, relative to the root.
|
||||
- `connection_string` **str** - (blob only) The Azure Storage connection string.
|
||||
- `container_name` **str** - (blob only) The Azure Storage container name.
|
||||
|
||||
@ -5,27 +5,27 @@
|
||||
| Name | Installation | Purpose |
|
||||
| ------------------- | ------------------------------------------------------------ | ----------------------------------------------------------------------------------- |
|
||||
| Python 3.10-3.12 | [Download](https://www.python.org/downloads/) | The library is Python-based. |
|
||||
| Poetry | [Instructions](https://python-poetry.org/docs/#installation) | Poetry is used for package management and virtualenv management in Python codebases |
|
||||
| uv | [Instructions](https://docs.astral.sh/uv/) | uv is used for package management and virtualenv management in Python codebases |
|
||||
|
||||
# Getting Started
|
||||
|
||||
## Install Dependencies
|
||||
|
||||
```sh
|
||||
# Install Python dependencies.
|
||||
poetry install
|
||||
# install python dependencies
|
||||
uv sync
|
||||
```
|
||||
|
||||
## Execute the Indexing Engine
|
||||
|
||||
```sh
|
||||
poetry run poe index <...args>
|
||||
uv run poe index <...args>
|
||||
```
|
||||
|
||||
## Executing Queries
|
||||
|
||||
```sh
|
||||
poetry run poe query <...args>
|
||||
uv run poe query <...args>
|
||||
```
|
||||
|
||||
# Azurite
|
||||
@ -40,31 +40,31 @@ or by simply running `azurite` in the terminal if already installed globally. Se
|
||||
|
||||
# Lifecycle Scripts
|
||||
|
||||
Our Python package utilizes Poetry to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage build scripts.
|
||||
Our Python package utilize uv to manage dependencies and [poethepoet](https://pypi.org/project/poethepoet/) to manage build scripts.
|
||||
|
||||
Available scripts are:
|
||||
|
||||
- `poetry run poe index` - Run the Indexing CLI
|
||||
- `poetry run poe query` - Run the Query CLI
|
||||
- `poetry build` - This invokes `poetry build`, which will build a wheel file and other distributable artifacts.
|
||||
- `poetry run poe test` - This will execute all tests.
|
||||
- `poetry run poe test_unit` - This will execute unit tests.
|
||||
- `poetry run poe test_integration` - This will execute integration tests.
|
||||
- `poetry run poe test_smoke` - This will execute smoke tests.
|
||||
- `poetry run poe test_verbs` - This will execute tests of the basic workflows.
|
||||
- `poetry run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- `uv run poe index` - Run the Indexing CLI
|
||||
- `uv run poe query` - Run the Query CLI
|
||||
- `uv build` - This will build a wheel file and other distributable artifacts.
|
||||
- `uv run poe test` - This will execute all tests.
|
||||
- `uv run poe test_unit` - This will execute unit tests.
|
||||
- `uv run poe test_integration` - This will execute integration tests.
|
||||
- `uv run poe test_smoke` - This will execute smoke tests.
|
||||
- `uv run poe test_verbs` - This will execute tests of the basic workflows.
|
||||
- `uv run poe check` - This will perform a suite of static checks across the package, including:
|
||||
- formatting
|
||||
- documentation formatting
|
||||
- linting
|
||||
- security patterns
|
||||
- type-checking
|
||||
- `poetry run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `poetry run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `poetry run poe format` - Explicitly run the formatter across the package.
|
||||
- `uv run poe fix` - This will apply any available auto-fixes to the package. Usually this is just formatting fixes.
|
||||
- `uv run poe fix_unsafe` - This will apply any available auto-fixes to the package, including those that may be unsafe.
|
||||
- `uv run poe format` - Explicitly run the formatter across the package.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running poetry install
|
||||
### "RuntimeError: llvm-config failed executing, please point LLVM_CONFIG to the path for llvm-config" when running uv install
|
||||
|
||||
Make sure llvm-9 and llvm-9-dev are installed:
|
||||
|
||||
@ -73,14 +73,3 @@ Make sure llvm-9 and llvm-9-dev are installed:
|
||||
and then in your bashrc, add
|
||||
|
||||
`export LLVM_CONFIG=/usr/bin/llvm-config-9`
|
||||
|
||||
### "numba/\_pymodule.h:6:10: fatal error: Python.h: No such file or directory" when running poetry install
|
||||
|
||||
Make sure you have python3.10-dev installed or more generally `python<version>-dev`
|
||||
|
||||
`sudo apt-get install python3.10-dev`
|
||||
|
||||
### LLM call constantly exceeds TPM, RPM or time limits
|
||||
|
||||
`GRAPHRAG_LLM_THREAD_COUNT` and `GRAPHRAG_EMBEDDING_THREAD_COUNT` are both set to 50 by default. You can modify these values
|
||||
to reduce concurrency. Please refer to the [Configuration Documents](config/overview.md)
|
||||
|
||||
@ -67,6 +67,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# note that we expect this to fail on the deployed docs because the PROJECT_DIRECTORY is not set to a real location.\n",
|
||||
"# if you run this notebook locally, make sure to point at a location containing your settings.yaml\n",
|
||||
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
|
||||
]
|
||||
},
|
||||
|
||||
680
docs/examples_notebooks/custom_vector_store.ipynb
Normal file
680
docs/examples_notebooks/custom_vector_store.ipynb
Normal file
@ -0,0 +1,680 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) 2024 Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Bring-Your-Own Vector Store\n",
|
||||
"\n",
|
||||
"This notebook demonstrates how to implement a custom vector store and register for usage with GraphRAG.\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"\n",
|
||||
"GraphRAG uses a plug-and-play architecture that allow for easy integration of custom vector stores (outside of what is natively supported) by following a factory design pattern. This allows you to:\n",
|
||||
"\n",
|
||||
"- **Extend functionality**: Add support for new vector database backends\n",
|
||||
"- **Customize behavior**: Implement specialized search logic or data structures\n",
|
||||
"- **Integrate existing systems**: Connect GraphRAG to your existing vector database infrastructure\n",
|
||||
"\n",
|
||||
"### What You'll Learn\n",
|
||||
"\n",
|
||||
"1. Understanding the `BaseVectorStore` interface\n",
|
||||
"2. Implementing a custom vector store class\n",
|
||||
"3. Registering your vector store with the `VectorStoreFactory`\n",
|
||||
"4. Testing and validating your implementation\n",
|
||||
"5. Configuring GraphRAG to use your custom vector store\n",
|
||||
"\n",
|
||||
"Let's get started!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 1: Import Required Dependencies\n",
|
||||
"\n",
|
||||
"First, let's import the necessary GraphRAG components and other dependencies we'll need.\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"pip install graphrag\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.data_model.types import TextEmbedder\n",
|
||||
"\n",
|
||||
"# GraphRAG vector store components\n",
|
||||
"from graphrag.vector_stores.base import (\n",
|
||||
" BaseVectorStore,\n",
|
||||
" VectorStoreDocument,\n",
|
||||
" VectorStoreSearchResult,\n",
|
||||
")\n",
|
||||
"from graphrag.vector_stores.factory import VectorStoreFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 2: Understand the BaseVectorStore Interface\n",
|
||||
"\n",
|
||||
"Before using a custom vector store, let's examine the `BaseVectorStore` interface to understand what methods need to be implemented."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's inspect the BaseVectorStore class to understand the required methods\n",
|
||||
"import inspect\n",
|
||||
"\n",
|
||||
"print(\"BaseVectorStore Abstract Methods:\")\n",
|
||||
"print(\"=\" * 40)\n",
|
||||
"\n",
|
||||
"abstract_methods = []\n",
|
||||
"for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n",
|
||||
" if getattr(method, \"__isabstractmethod__\", False):\n",
|
||||
" signature = inspect.signature(method)\n",
|
||||
" abstract_methods.append(f\"• {name}{signature}\")\n",
|
||||
" print(f\"• {name}{signature}\")\n",
|
||||
"\n",
|
||||
"print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 3: Implement a Custom Vector Store\n",
|
||||
"\n",
|
||||
"Now let's implement a simple in-memory vector store as an example. This vector store will:\n",
|
||||
"\n",
|
||||
"- Store documents and vectors in memory using Python data structures\n",
|
||||
"- Support all required BaseVectorStore methods\n",
|
||||
"\n",
|
||||
"**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SimpleInMemoryVectorStore(BaseVectorStore):\n",
|
||||
" \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n",
|
||||
"\n",
|
||||
" This vector store stores documents and their embeddings in memory and provides\n",
|
||||
" basic similarity search functionality using cosine similarity.\n",
|
||||
"\n",
|
||||
" WARNING: This is for demonstration only - not suitable for production use.\n",
|
||||
" For production, consider using optimized vector databases like LanceDB,\n",
|
||||
" Azure AI Search, or other specialized vector stores.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # Internal storage for documents and vectors\n",
|
||||
" documents: dict[str, VectorStoreDocument]\n",
|
||||
" vectors: dict[str, np.ndarray]\n",
|
||||
" connected: bool\n",
|
||||
"\n",
|
||||
" def __init__(self, **kwargs: Any):\n",
|
||||
" \"\"\"Initialize the in-memory vector store.\"\"\"\n",
|
||||
" super().__init__(**kwargs)\n",
|
||||
"\n",
|
||||
" self.documents: dict[str, VectorStoreDocument] = {}\n",
|
||||
" self.vectors: dict[str, np.ndarray] = {}\n",
|
||||
" self.connected = False\n",
|
||||
"\n",
|
||||
" print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def connect(self, **kwargs: Any) -> None:\n",
|
||||
" \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n",
|
||||
" self.connected = True\n",
|
||||
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
|
||||
"\n",
|
||||
" def load_documents(\n",
|
||||
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
|
||||
" ) -> None:\n",
|
||||
" \"\"\"Load documents into the vector store.\"\"\"\n",
|
||||
" if not self.connected:\n",
|
||||
" msg = \"Vector store not connected. Call connect() first.\"\n",
|
||||
" raise RuntimeError(msg)\n",
|
||||
"\n",
|
||||
" if overwrite:\n",
|
||||
" self.documents.clear()\n",
|
||||
" self.vectors.clear()\n",
|
||||
"\n",
|
||||
" loaded_count = 0\n",
|
||||
" for doc in documents:\n",
|
||||
" if doc.vector is not None:\n",
|
||||
" doc_id = str(doc.id)\n",
|
||||
" self.documents[doc_id] = doc\n",
|
||||
" self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n",
|
||||
" loaded_count += 1\n",
|
||||
"\n",
|
||||
" print(f\"📚 Loaded {loaded_count} documents into vector store\")\n",
|
||||
"\n",
|
||||
" def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n",
|
||||
" \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n",
|
||||
" # Normalize vectors\n",
|
||||
" norm1 = np.linalg.norm(vec1)\n",
|
||||
" norm2 = np.linalg.norm(vec2)\n",
|
||||
"\n",
|
||||
" if norm1 == 0 or norm2 == 0:\n",
|
||||
" return 0.0\n",
|
||||
"\n",
|
||||
" return float(np.dot(vec1, vec2) / (norm1 * norm2))\n",
|
||||
"\n",
|
||||
" def similarity_search_by_vector(\n",
|
||||
" self, query_embedding: list[float], k: int = 10, **kwargs: Any\n",
|
||||
" ) -> list[VectorStoreSearchResult]:\n",
|
||||
" \"\"\"Perform similarity search using a query vector.\"\"\"\n",
|
||||
" if not self.connected:\n",
|
||||
" msg = \"Vector store not connected. Call connect() first.\"\n",
|
||||
" raise RuntimeError(msg)\n",
|
||||
"\n",
|
||||
" if not self.vectors:\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" query_vec = np.array(query_embedding, dtype=np.float32)\n",
|
||||
" similarities = []\n",
|
||||
"\n",
|
||||
" # Calculate similarity with all stored vectors\n",
|
||||
" for doc_id, stored_vec in self.vectors.items():\n",
|
||||
" similarity = self._cosine_similarity(query_vec, stored_vec)\n",
|
||||
" similarities.append((doc_id, similarity))\n",
|
||||
"\n",
|
||||
" # Sort by similarity (descending) and take top k\n",
|
||||
" similarities.sort(key=lambda x: x[1], reverse=True)\n",
|
||||
" top_k = similarities[:k]\n",
|
||||
"\n",
|
||||
" # Create search results\n",
|
||||
" results = []\n",
|
||||
" for doc_id, score in top_k:\n",
|
||||
" document = self.documents[doc_id]\n",
|
||||
" result = VectorStoreSearchResult(document=document, score=score)\n",
|
||||
" results.append(result)\n",
|
||||
"\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
" def similarity_search_by_text(\n",
|
||||
" self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n",
|
||||
" ) -> list[VectorStoreSearchResult]:\n",
|
||||
" \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n",
|
||||
" # Embed the text first\n",
|
||||
" query_embedding = text_embedder(text)\n",
|
||||
"\n",
|
||||
" # Use vector search with the embedding\n",
|
||||
" return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
|
||||
"\n",
|
||||
" def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
|
||||
" \"\"\"Build a query filter to filter documents by id.\n",
|
||||
"\n",
|
||||
" For this simple implementation, we return the list of IDs as the filter.\n",
|
||||
" \"\"\"\n",
|
||||
" return [str(id_) for id_ in include_ids]\n",
|
||||
"\n",
|
||||
" def search_by_id(self, id: str) -> VectorStoreDocument:\n",
|
||||
" \"\"\"Search for a document by id.\"\"\"\n",
|
||||
" doc_id = str(id)\n",
|
||||
" if doc_id not in self.documents:\n",
|
||||
" msg = f\"Document with id '{id}' not found\"\n",
|
||||
" raise KeyError(msg)\n",
|
||||
"\n",
|
||||
" return self.documents[doc_id]\n",
|
||||
"\n",
|
||||
" def get_stats(self) -> dict[str, Any]:\n",
|
||||
" \"\"\"Get statistics about the vector store (custom method).\"\"\"\n",
|
||||
" return {\n",
|
||||
" \"index_name\": self.index_name,\n",
|
||||
" \"document_count\": len(self.documents),\n",
|
||||
" \"vector_count\": len(self.vectors),\n",
|
||||
" \"connected\": self.connected,\n",
|
||||
" \"vector_dimension\": len(next(iter(self.vectors.values())))\n",
|
||||
" if self.vectors\n",
|
||||
" else 0,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"✅ SimpleInMemoryVectorStore class defined!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 4: Register the Custom Vector Store\n",
|
||||
"\n",
|
||||
"Now let's register our custom vector store with the `VectorStoreFactory` so it can be used throughout GraphRAG."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Register our custom vector store with a unique identifier\n",
|
||||
"CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
|
||||
"\n",
|
||||
"# Register the vector store class\n",
|
||||
"VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
|
||||
"\n",
|
||||
"print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
|
||||
"\n",
|
||||
"# Verify registration\n",
|
||||
"available_types = VectorStoreFactory.get_vector_store_types()\n",
|
||||
"print(f\"\\n📋 Available vector store types: {available_types}\")\n",
|
||||
"print(\n",
|
||||
" f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 5: Test the Custom Vector Store\n",
|
||||
"\n",
|
||||
"Let's create some sample data and test our custom vector store implementation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create sample documents with mock embeddings\n",
|
||||
"def create_mock_embedding(dimension: int = 384) -> list[float]:\n",
|
||||
" \"\"\"Create a random embedding vector for testing.\"\"\"\n",
|
||||
" return np.random.normal(0, 1, dimension).tolist()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Sample documents\n",
|
||||
"sample_documents = [\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_1\",\n",
|
||||
" text=\"GraphRAG is a powerful knowledge graph extraction and reasoning framework.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"technology\", \"source\": \"documentation\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_2\",\n",
|
||||
" text=\"Vector stores enable efficient similarity search over high-dimensional data.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"technology\", \"source\": \"research\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_3\",\n",
|
||||
" text=\"Machine learning models can process and understand natural language text.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"AI\", \"source\": \"article\"},\n",
|
||||
" ),\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=\"doc_4\",\n",
|
||||
" text=\"Custom implementations allow for specialized behavior and integration.\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"category\": \"development\", \"source\": \"tutorial\"},\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(f\"📝 Created {len(sample_documents)} sample documents\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test creating vector store using the factory\n",
|
||||
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
|
||||
"\n",
|
||||
"# Create vector store instance using factory\n",
|
||||
"vector_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
|
||||
"print(f\"📊 Initial stats: {vector_store.get_stats()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Connect and load documents\n",
|
||||
"vector_store.connect()\n",
|
||||
"vector_store.load_documents(sample_documents)\n",
|
||||
"\n",
|
||||
"print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test similarity search\n",
|
||||
"query_vector = create_mock_embedding() # Random query vector for testing\n",
|
||||
"\n",
|
||||
"search_results = vector_store.similarity_search_by_vector(\n",
|
||||
" query_vector,\n",
|
||||
" k=3, # Get top 3 similar documents\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"🔍 Found {len(search_results)} similar documents:\\n\")\n",
|
||||
"\n",
|
||||
"for i, result in enumerate(search_results, 1):\n",
|
||||
" doc = result.document\n",
|
||||
" print(f\"{i}. ID: {doc.id}\")\n",
|
||||
" print(f\" Text: {doc.text[:60]}...\")\n",
|
||||
" print(f\" Similarity Score: {result.score:.4f}\")\n",
|
||||
" print(f\" Category: {doc.attributes.get('category', 'N/A')}\")\n",
|
||||
" print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test search by ID\n",
|
||||
"try:\n",
|
||||
" found_doc = vector_store.search_by_id(\"doc_2\")\n",
|
||||
" print(\"✅ Found document by ID:\")\n",
|
||||
" print(f\" ID: {found_doc.id}\")\n",
|
||||
" print(f\" Text: {found_doc.text}\")\n",
|
||||
" print(f\" Attributes: {found_doc.attributes}\")\n",
|
||||
"except KeyError as e:\n",
|
||||
" print(f\"❌ Error: {e}\")\n",
|
||||
"\n",
|
||||
"# Test filter by ID\n",
|
||||
"id_filter = vector_store.filter_by_id([\"doc_1\", \"doc_3\"])\n",
|
||||
"print(f\"\\n🔧 ID filter result: {id_filter}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 6: Configuration for GraphRAG\n",
|
||||
"\n",
|
||||
"Now let's see how you would configure GraphRAG to use your custom vector store in a settings file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Example GraphRAG yaml settings\n",
|
||||
"example_settings = {\n",
|
||||
" \"vector_store\": {\n",
|
||||
" \"default_vector_store\": {\n",
|
||||
" \"type\": CUSTOM_VECTOR_STORE_TYPE, # \"simple_memory\"\n",
|
||||
" \"collection_name\": \"graphrag_entities\",\n",
|
||||
" # Add any custom parameters your vector store needs\n",
|
||||
" \"custom_parameter\": \"custom_value\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" # Other GraphRAG configuration...\n",
|
||||
" \"models\": {\n",
|
||||
" \"default_embedding_model\": {\n",
|
||||
" \"type\": \"openai_embedding\",\n",
|
||||
" \"model\": \"text-embedding-3-small\",\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Convert to YAML format for settings.yml\n",
|
||||
"yaml_config = yaml.dump(example_settings, default_flow_style=False, indent=2)\n",
|
||||
"\n",
|
||||
"print(\"📄 Example settings.yml configuration:\")\n",
|
||||
"print(\"=\" * 40)\n",
|
||||
"print(yaml_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 7: Integration with GraphRAG Pipeline\n",
|
||||
"\n",
|
||||
"Here's how your custom vector store would be used in a typical GraphRAG pipeline."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Example of how GraphRAG would use your custom vector store\n",
|
||||
"def simulate_graphrag_pipeline():\n",
|
||||
" \"\"\"Simulate how GraphRAG would use the custom vector store.\"\"\"\n",
|
||||
" print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n",
|
||||
"\n",
|
||||
" # 1. GraphRAG creates vector store using factory\n",
|
||||
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
|
||||
"\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=schema,\n",
|
||||
" similarity_threshold=0.3,\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
" print(\"✅ Step 1: Vector store created and connected\")\n",
|
||||
"\n",
|
||||
" # 2. During indexing, GraphRAG loads extracted entities\n",
|
||||
" entity_documents = [\n",
|
||||
" VectorStoreDocument(\n",
|
||||
" id=f\"entity_{i}\",\n",
|
||||
" text=f\"Entity {i} description: Important concept in the knowledge graph\",\n",
|
||||
" vector=create_mock_embedding(),\n",
|
||||
" attributes={\"type\": \"entity\", \"importance\": i % 3 + 1},\n",
|
||||
" )\n",
|
||||
" for i in range(10)\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" store.load_documents(entity_documents)\n",
|
||||
" print(f\"✅ Step 2: Loaded {len(entity_documents)} entity documents\")\n",
|
||||
"\n",
|
||||
" # 3. During query time, GraphRAG searches for relevant entities\n",
|
||||
" query_embedding = create_mock_embedding()\n",
|
||||
" relevant_entities = store.similarity_search_by_vector(query_embedding, k=5)\n",
|
||||
"\n",
|
||||
" print(f\"✅ Step 3: Found {len(relevant_entities)} relevant entities for query\")\n",
|
||||
"\n",
|
||||
" # 4. GraphRAG uses these entities for context building\n",
|
||||
" context_entities = [result.document for result in relevant_entities]\n",
|
||||
"\n",
|
||||
" print(\"✅ Step 4: Context built using retrieved entities\")\n",
|
||||
" print(f\"📊 Final stats: {store.get_stats()}\")\n",
|
||||
"\n",
|
||||
" return context_entities\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the simulation\n",
|
||||
"context = simulate_graphrag_pipeline()\n",
|
||||
"print(f\"\\n🎯 Retrieved {len(context)} entities for context building\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Step 8: Testing and Validation\n",
|
||||
"\n",
|
||||
"Let's create a comprehensive test suite to ensure our vector store works correctly."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_custom_vector_store():\n",
|
||||
" \"\"\"Comprehensive test suite for the custom vector store.\"\"\"\n",
|
||||
" print(\"🧪 Running comprehensive vector store tests...\\n\")\n",
|
||||
"\n",
|
||||
" # Test 1: Basic functionality\n",
|
||||
" print(\"Test 1: Basic functionality\")\n",
|
||||
" store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
|
||||
" )\n",
|
||||
" store.connect()\n",
|
||||
"\n",
|
||||
" # Load test documents\n",
|
||||
" test_docs = sample_documents[:2]\n",
|
||||
" store.load_documents(test_docs)\n",
|
||||
"\n",
|
||||
" assert len(store.documents) == 2, \"Should have 2 documents\"\n",
|
||||
" assert len(store.vectors) == 2, \"Should have 2 vectors\"\n",
|
||||
" print(\"✅ Basic functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 2: Search functionality\n",
|
||||
" print(\"\\nTest 2: Search functionality\")\n",
|
||||
" query_vec = create_mock_embedding()\n",
|
||||
" results = store.similarity_search_by_vector(query_vec, k=5)\n",
|
||||
"\n",
|
||||
" assert len(results) <= 2, \"Should not return more results than documents\"\n",
|
||||
" assert all(isinstance(r, VectorStoreSearchResult) for r in results), (\n",
|
||||
" \"Should return VectorStoreSearchResult objects\"\n",
|
||||
" )\n",
|
||||
" assert all(-1 <= r.score <= 1 for r in results), (\n",
|
||||
" \"Similarity scores should be between -1 and 1\"\n",
|
||||
" )\n",
|
||||
" print(\"✅ Search functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 3: Search by ID\n",
|
||||
" print(\"\\nTest 3: Search by ID\")\n",
|
||||
" found_doc = store.search_by_id(\"doc_1\")\n",
|
||||
" assert found_doc.id == \"doc_1\", \"Should find correct document\"\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" store.search_by_id(\"nonexistent\")\n",
|
||||
" assert False, \"Should raise KeyError for nonexistent ID\"\n",
|
||||
" except KeyError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" print(\"✅ Search by ID test passed\")\n",
|
||||
"\n",
|
||||
" # Test 4: Filter functionality\n",
|
||||
" print(\"\\nTest 4: Filter functionality\")\n",
|
||||
" filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
|
||||
" assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
|
||||
" print(\"✅ Filter functionality test passed\")\n",
|
||||
"\n",
|
||||
" # Test 5: Error handling\n",
|
||||
" print(\"\\nTest 5: Error handling\")\n",
|
||||
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
|
||||
" CUSTOM_VECTOR_STORE_TYPE,\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" disconnected_store.load_documents(test_docs)\n",
|
||||
" assert False, \"Should raise error when not connected\"\n",
|
||||
" except RuntimeError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" disconnected_store.similarity_search_by_vector(query_vec)\n",
|
||||
" assert False, \"Should raise error when not connected\"\n",
|
||||
" except RuntimeError:\n",
|
||||
" pass # Expected\n",
|
||||
"\n",
|
||||
" print(\"✅ Error handling test passed\")\n",
|
||||
"\n",
|
||||
" print(\"\\n🎉 All tests passed! Your custom vector store is working correctly.\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Run the tests\n",
|
||||
"test_custom_vector_store()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary and Next Steps\n",
|
||||
"\n",
|
||||
"Congratulations! You've successfully learned how to implement and register a custom vector store with GraphRAG. Here's what you accomplished:\n",
|
||||
"\n",
|
||||
"### What You Built\n",
|
||||
"- ✅ **Custom Vector Store Class**: Implemented `SimpleInMemoryVectorStore` with all required methods\n",
|
||||
"- ✅ **Factory Integration**: Registered your vector store with `VectorStoreFactory`\n",
|
||||
"- ✅ **Comprehensive Testing**: Validated functionality with a full test suite\n",
|
||||
"- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n",
|
||||
"\n",
|
||||
"### Key Takeaways\n",
|
||||
"1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n",
|
||||
"2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n",
|
||||
"3. **Configuration**: Vector stores are configured in GraphRAG settings files\n",
|
||||
"4. **Testing**: Thoroughly test all functionality before deploying\n",
|
||||
"\n",
|
||||
"### Next Steps\n",
|
||||
"Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n",
|
||||
"\n",
|
||||
"### Resources\n",
|
||||
"- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n",
|
||||
"\n",
|
||||
"Happy building! 🚀"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@ -20,11 +20,11 @@
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_entities,\n",
|
||||
@ -37,6 +37,7 @@
|
||||
" DRIFTSearchContextBuilder,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"from graphrag.vector_stores.lancedb import LanceDBVectorStore\n",
|
||||
"\n",
|
||||
"INPUT_DIR = \"./inputs/operation dulce\"\n",
|
||||
@ -62,12 +63,16 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" ),\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
"full_content_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-community-full_content\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-community-full_content\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -94,33 +99,33 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"tokenizer = get_tokenizer(chat_config)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
]
|
||||
@ -173,12 +178,12 @@
|
||||
" reports=reports,\n",
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" text_units=text_units,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" config=drift_params,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"search = DRIFTSearch(\n",
|
||||
" model=chat_model, context_builder=context_builder, token_encoder=token_encoder\n",
|
||||
" model=chat_model, context_builder=context_builder, tokenizer=tokenizer\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -212,7 +217,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -226,7 +231,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -32,7 +31,8 @@
|
||||
"from graphrag.query.structured_search.global_search.community_context import (\n",
|
||||
" GlobalCommunityContext,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -58,21 +58,21 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -142,7 +142,7 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -193,7 +193,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -241,7 +241,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -255,7 +255,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
@ -57,22 +56,24 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"global_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
"tokenizer = get_tokenizer(config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -155,11 +156,11 @@
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" dynamic_community_selection=True,\n",
|
||||
" dynamic_community_selection_kwargs={\n",
|
||||
" \"model\": model,\n",
|
||||
" \"token_encoder\": token_encoder,\n",
|
||||
" \"tokenizer\": tokenizer,\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
@ -211,7 +212,7 @@
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" model=model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
@ -255,7 +256,7 @@
|
||||
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
|
||||
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
|
||||
"print(\n",
|
||||
" f\"Build context ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Build context LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")\n",
|
||||
"# inspect number of LLM calls and tokens in map-reduce\n",
|
||||
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
|
||||
@ -266,14 +267,14 @@
|
||||
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"Map-reduce ({llm_model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
" f\"Map-reduce LLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -287,7 +288,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -202,10 +202,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
|
||||
"\n",
|
||||
"from graphrag.cache.factory import CacheFactory\n",
|
||||
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
|
||||
"from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
|
||||
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
|
||||
"\n",
|
||||
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
|
||||
"# We'll construct the context and run this function flow directly to avoid everything else\n",
|
||||
|
||||
194
docs/examples_notebooks/input_documents.ipynb
Normal file
194
docs/examples_notebooks/input_documents.ipynb
Normal file
@ -0,0 +1,194 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) 2024 Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example of indexing from an existing in-memory dataframe\n",
|
||||
"\n",
|
||||
"Newer versions of GraphRAG let you submit a dataframe directly instead of running through the input processing step. This notebook demonstrates with regular or update runs.\n",
|
||||
"\n",
|
||||
"If performing an update, the assumption is that your dataframe contains only the new documents to add to the index."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"import graphrag.api as api\n",
|
||||
"from graphrag.config.load_config import load_config\n",
|
||||
"from graphrag.index.typing.pipeline_run_result import PipelineRunResult"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PROJECT_DIRECTORY = \"<your project directory>\"\n",
|
||||
"UPDATE = False\n",
|
||||
"FILENAME = \"new_documents.parquet\" if UPDATE else \"<original_documents>.parquet\"\n",
|
||||
"inputs = pd.read_parquet(f\"{PROJECT_DIRECTORY}/input/{FILENAME}\")\n",
|
||||
"# Only the bare minimum for input. These are the same fields that would be present after the load_input_documents workflow\n",
|
||||
"inputs = inputs.loc[:, [\"id\", \"title\", \"text\", \"creation_date\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generate a `GraphRagConfig` object"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graphrag_config = load_config(Path(PROJECT_DIRECTORY))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Indexing API\n",
|
||||
"\n",
|
||||
"*Indexing* is the process of ingesting raw text data and constructing a knowledge graph. GraphRAG currently supports plaintext (`.txt`) and `.csv` file formats."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Build an index"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"index_result: list[PipelineRunResult] = await api.build_index(\n",
|
||||
" config=graphrag_config, input_documents=inputs, is_update_run=UPDATE\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# index_result is a list of workflows that make up the indexing pipeline that was run\n",
|
||||
"for workflow_result in index_result:\n",
|
||||
" status = f\"error\\n{workflow_result.errors}\" if workflow_result.errors else \"success\"\n",
|
||||
" print(f\"Workflow Name: {workflow_result.workflow}\\tStatus: {status}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Query an index\n",
|
||||
"\n",
|
||||
"To query an index, several index files must first be read into memory and passed to the query API. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/entities.parquet\")\n",
|
||||
"communities = pd.read_parquet(f\"{PROJECT_DIRECTORY}/output/communities.parquet\")\n",
|
||||
"community_reports = pd.read_parquet(\n",
|
||||
" f\"{PROJECT_DIRECTORY}/output/community_reports.parquet\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response, context = await api.global_search(\n",
|
||||
" config=graphrag_config,\n",
|
||||
" entities=entities,\n",
|
||||
" communities=communities,\n",
|
||||
" community_reports=community_reports,\n",
|
||||
" community_level=2,\n",
|
||||
" dynamic_community_selection=False,\n",
|
||||
" response_type=\"Multiple Paragraphs\",\n",
|
||||
" query=\"What are the top five themes of the dataset?\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The response object is the official reponse from graphrag while the context object holds various metadata regarding the querying process used to obtain the final response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Digging into the context a bit more provides users with extremely granular information such as what sources of data (down to the level of text chunks) were ultimately retrieved and used as part of the context sent to the LLM model)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pprint(context) # noqa: T203"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@ -19,8 +19,8 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n",
|
||||
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_covariates,\n",
|
||||
@ -102,7 +102,9 @@
|
||||
"# load description embeddings to an in-memory lancedb vectorstore\n",
|
||||
"# to connect to a remote db, specify url and port values.\n",
|
||||
"description_embedding_store = LanceDBVectorStore(\n",
|
||||
" collection_name=\"default-entity-description\",\n",
|
||||
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
|
||||
" index_name=\"default-entity-description\"\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
|
||||
"\n",
|
||||
@ -195,37 +197,38 @@
|
||||
"from graphrag.config.enums import ModelType\n",
|
||||
"from graphrag.config.models.language_model_config import LanguageModelConfig\n",
|
||||
"from graphrag.language_model.manager import ModelManager\n",
|
||||
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
|
||||
"\n",
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"embedding_model = os.environ[\"GRAPHRAG_EMBEDDING_MODEL\"]\n",
|
||||
"\n",
|
||||
"chat_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIChat,\n",
|
||||
" model=llm_model,\n",
|
||||
" type=ModelType.Chat,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"gpt-4.1\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"chat_model = ModelManager().get_or_create_chat_model(\n",
|
||||
" name=\"local_search\",\n",
|
||||
" model_type=ModelType.OpenAIChat,\n",
|
||||
" model_type=ModelType.Chat,\n",
|
||||
" config=chat_config,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)\n",
|
||||
"\n",
|
||||
"embedding_config = LanguageModelConfig(\n",
|
||||
" api_key=api_key,\n",
|
||||
" type=ModelType.OpenAIEmbedding,\n",
|
||||
" model=embedding_model,\n",
|
||||
" type=ModelType.Embedding,\n",
|
||||
" model_provider=\"openai\",\n",
|
||||
" model=\"text-embedding-3-small\",\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"text_embedder = ModelManager().get_or_create_embedding_model(\n",
|
||||
" name=\"local_search_embedding\",\n",
|
||||
" model_type=ModelType.OpenAIEmbedding,\n",
|
||||
" model_type=ModelType.Embedding,\n",
|
||||
" config=embedding_config,\n",
|
||||
")"
|
||||
")\n",
|
||||
"\n",
|
||||
"tokenizer = get_tokenizer(chat_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -251,7 +254,7 @@
|
||||
" entity_text_embeddings=description_embedding_store,\n",
|
||||
" embedding_vectorstore_key=EntityVectorStoreKey.ID, # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE\n",
|
||||
" text_embedder=text_embedder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@ -314,7 +317,7 @@
|
||||
"search_engine = LocalSearch(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
|
||||
@ -426,7 +429,7 @@
|
||||
"question_generator = LocalQuestionGen(\n",
|
||||
" model=chat_model,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" model_params=model_params,\n",
|
||||
" context_builder_params=local_context_params,\n",
|
||||
")"
|
||||
@ -451,7 +454,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -465,7 +468,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
"version": "3.12.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
# Getting Started
|
||||
|
||||
⚠️ GraphRAG can consume a lot of LLM resources! We strongly recommend starting with the tutorial dataset here until you understand how the system works, and consider experimenting with fast/inexpensive models first before committing to a big indexing job.
|
||||
|
||||
## Requirements
|
||||
|
||||
[Python 3.10-3.12](https://www.python.org/downloads/)
|
||||
|
||||
To get started with the GraphRAG system, you have a few options:
|
||||
|
||||
👉 [Use the GraphRAG Accelerator solution](https://github.com/Azure-Samples/graphrag-accelerator) <br/>
|
||||
👉 [Install from pypi](https://pypi.org/project/graphrag/). <br/>
|
||||
👉 [Use it from source](developing.md)<br/>
|
||||
|
||||
@ -25,25 +26,25 @@ pip install graphrag
|
||||
We need to set up a data project and some initial configuration. First let's get a sample dataset ready:
|
||||
|
||||
```sh
|
||||
mkdir -p ./ragtest/input
|
||||
mkdir -p ./christmas/input
|
||||
```
|
||||
|
||||
Get a copy of A Christmas Carol by Charles Dickens from a trusted source:
|
||||
|
||||
```sh
|
||||
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./ragtest/input/book.txt
|
||||
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./christmas/input/book.txt
|
||||
```
|
||||
|
||||
## Set Up Your Workspace Variables
|
||||
|
||||
To initialize your workspace, first run the `graphrag init` command.
|
||||
Since we have already configured a directory named `./ragtest` in the previous step, run the following command:
|
||||
Since we have already configured a directory named `./christmas` in the previous step, run the following command:
|
||||
|
||||
```sh
|
||||
graphrag init --root ./ragtest
|
||||
graphrag init --root ./christmas
|
||||
```
|
||||
|
||||
This will create two files: `.env` and `settings.yaml` in the `./ragtest` directory.
|
||||
This will create two files: `.env` and `settings.yaml` in the `./christmas` directory.
|
||||
|
||||
- `.env` contains the environment variables required to run the GraphRAG pipeline. If you inspect the file, you'll see a single environment variable defined,
|
||||
`GRAPHRAG_API_KEY=<API_KEY>`. Replace `<API_KEY>` with your own OpenAI or Azure API key.
|
||||
@ -66,11 +67,10 @@ deployment_name: <azure_model_deployment_name>
|
||||
```
|
||||
|
||||
#### Using Managed Auth on Azure
|
||||
To use managed auth, add an additional value to your model config and comment out or remove the api_key line:
|
||||
To use managed auth, edit the auth_type in your model config and *remove* the api_key line:
|
||||
|
||||
```yaml
|
||||
auth_type: azure_managed_identity # Default auth_type is is api_key
|
||||
# api_key: ${GRAPHRAG_API_KEY}
|
||||
```
|
||||
|
||||
You will also need to login with [az login](https://learn.microsoft.com/en-us/cli/azure/authenticate-azure-cli) and select the subscription with your endpoint.
|
||||
@ -80,13 +80,13 @@ You will also need to login with [az login](https://learn.microsoft.com/en-us/cl
|
||||
Finally we'll run the pipeline!
|
||||
|
||||
```sh
|
||||
graphrag index --root ./ragtest
|
||||
graphrag index --root ./christmas
|
||||
```
|
||||
|
||||

|
||||
|
||||
This process will take some time to run. This depends on the size of your input data, what model you're using, and the text chunk size being used (these can be configured in your `settings.yaml` file).
|
||||
Once the pipeline is complete, you should see a new folder called `./ragtest/output` with a series of parquet files.
|
||||
Once the pipeline is complete, you should see a new folder called `./christmas/output` with a series of parquet files.
|
||||
|
||||
# Using the Query Engine
|
||||
|
||||
@ -96,7 +96,7 @@ Here is an example using Global search to ask a high-level question:
|
||||
|
||||
```sh
|
||||
graphrag query \
|
||||
--root ./ragtest \
|
||||
--root ./christmas \
|
||||
--method global \
|
||||
--query "What are the top themes in this story?"
|
||||
```
|
||||
@ -105,7 +105,7 @@ Here is an example using Local search to ask a more specific question about a pa
|
||||
|
||||
```sh
|
||||
graphrag query \
|
||||
--root ./ragtest \
|
||||
--root ./christmas \
|
||||
--method local \
|
||||
--query "Who is Scrooge and what are his main relationships?"
|
||||
```
|
||||
@ -117,4 +117,4 @@ Please refer to [Query Engine](query/overview.md) docs for detailed information
|
||||
- For more details about configuring GraphRAG, see the [configuration documentation](config/overview.md).
|
||||
- To learn more about Initialization, refer to the [Initialization documentation](config/init.md).
|
||||
- For more details about using the CLI, refer to the [CLI documentation](cli.md).
|
||||
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.
|
||||
- Check out our [visualization guide](visualization_guide.md) for a more interactive experience in debugging and exploring the knowledge graph.
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Welcome to GraphRAG
|
||||
|
||||
👉 [Microsoft Research Blog Post](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/) <br/>
|
||||
👉 [GraphRAG Accelerator](https://github.com/Azure-Samples/graphrag-accelerator) <br/>
|
||||
👉 [GraphRAG Arxiv](https://arxiv.org/pdf/2404.16130)
|
||||
|
||||
<p align="center">
|
||||
@ -16,10 +15,6 @@ approaches using plain text snippets. The GraphRAG process involves extracting a
|
||||
|
||||
To learn more about GraphRAG and how it can be used to enhance your language model's ability to reason about your private data, please visit the [Microsoft Research Blog Post](https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/).
|
||||
|
||||
## Solution Accelerator 🚀
|
||||
|
||||
To quickstart the GraphRAG system we recommend trying the [Solution Accelerator](https://github.com/Azure-Samples/graphrag-accelerator) package. This provides a user-friendly end-to-end experience with Azure resources.
|
||||
|
||||
## Get Started with GraphRAG 🚀
|
||||
|
||||
To start using GraphRAG, check out the [_Get Started_](get_started.md) guide.
|
||||
@ -52,6 +47,7 @@ At query time, these structures are used to provide materials for the LLM contex
|
||||
- [_Global Search_](query/global_search.md) for reasoning about holistic questions about the corpus by leveraging the community summaries.
|
||||
- [_Local Search_](query/local_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts.
|
||||
- [_DRIFT Search_](query/drift_search.md) for reasoning about specific entities by fanning-out to their neighbors and associated concepts, but with the added context of community information.
|
||||
- _Basic Search_ for those times when your query is best answered by baseline RAG (standard top _k_ vector search).
|
||||
|
||||
### Prompt Tuning
|
||||
|
||||
|
||||
@ -32,3 +32,20 @@ The GraphRAG library was designed with LLM interactions in mind, and a common se
|
||||
Because of these potential error cases, we've added a cache layer around LLM interactions.
|
||||
When completion requests are made using the same input set (prompt and tuning parameters), we return a cached result if one exists.
|
||||
This allows our indexer to be more resilient to network issues, to act idempotently, and to provide a more efficient end-user experience.
|
||||
|
||||
### Providers & Factories
|
||||
|
||||
Several subsystems within GraphRAG use a factory pattern to register and retrieve provider implementations. This allows deep customization to support models, storage, and so on that you may use but isn't built directly into GraphRAG.
|
||||
|
||||
The following subsystems use a factory pattern that allows you to register your own implementations:
|
||||
|
||||
- [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support
|
||||
- [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide
|
||||
- [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage
|
||||
- [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in
|
||||
- [vector store](https://github.com/microsoft/graphrag/blob/main/graphrag/vector_stores/factory.py) - implement your own vector store other than the built-in lancedb, Azure AI Search, and CosmosDB ones built in
|
||||
- [pipeline + workflows](https://github.com/microsoft/graphrag/blob/main/graphrag/index/workflows/factory.py) - implement your own workflow steps with a custom `run_workflow` function, or register an entire pipeline (list of named workflows)
|
||||
|
||||
The links for each of these subsystems point to the source code of the factory, which includes registration of the default built-in implementations. In addition, we have a detailed discussion of [language models](../config/models.md), which includes and example of a custom provider, and a [sample notebook](../examples_notebooks/custom_vector_store.ipynb) that demonstrates a custom vector store.
|
||||
|
||||
All of these factories allow you to register an impl using any string name you would like, even overriding built-in ones directly.
|
||||
@ -52,7 +52,7 @@ workflows: [create_communities, create_community_reports, generate_text_embeddin
|
||||
|
||||
### FastGraphRAG
|
||||
|
||||
[FastGraphRAG](./methods.md#fastgraphrag) uses text_units for the community reports instead of the entity and relationship descriptions. If your graph is sourced in such a way that it does not have descriptions, this might be a useful alternative. In this case, you would update your workflows list to include the text variant:
|
||||
[FastGraphRAG](./methods.md#fastgraphrag) uses text_units for the community reports instead of the entity and relationship descriptions. If your graph is sourced in such a way that it does not have descriptions, this might be a useful alternative. In this case, you would update your workflows list to include the text variant of the community reports workflow:
|
||||
|
||||
```yaml
|
||||
workflows: [create_communities, create_community_reports_text, generate_text_embeddings]
|
||||
@ -65,7 +65,6 @@ This method requires that your entities and relationships tables have valid link
|
||||
|
||||
Putting it all together:
|
||||
|
||||
- `input`: GraphRAG does require an input document set, even if you don't need us to process it. You can create an input folder and drop a dummy.txt document in there to work around this.
|
||||
- `output`: Create an output folder and put your entities and relationships (and optionally text_units) parquet files in it.
|
||||
- Update your config as noted above to only run the workflows subset you need.
|
||||
- Run `graphrag index --root <your project root>`
|
||||
@ -16,6 +16,10 @@ All input formats are loaded within GraphRAG and passed to the indexing pipeline
|
||||
|
||||
Also see the [outputs](outputs.md) documentation for the final documents table schema saved to parquet after pipeline completion.
|
||||
|
||||
## Bring-your-own DataFrame
|
||||
|
||||
As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. __You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same.
|
||||
|
||||
## Formats
|
||||
|
||||
We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries.
|
||||
|
||||
@ -10,7 +10,7 @@ This is the method described in the original [blog post](https://www.microsoft.c
|
||||
- relationship extraction: LLM is prompted to describe the relationship between each pair of entities in each text unit.
|
||||
- entity summarization: LLM is prompted to combine the descriptions for every instance of an entity found across the text units into a single summary.
|
||||
- relationship summarization: LLM is prompted to combine the descriptions for every instance of a relationship found across the text units into a single summary.
|
||||
- claim extraction (optiona): LLM is prompted to extract and describe claims from each text unit.
|
||||
- claim extraction (optional): LLM is prompted to extract and describe claims from each text unit.
|
||||
- community report generation: entity and relationship descriptions (and optionally claims) for each community are collected and used to prompt the LLM to generate a summary report.
|
||||
|
||||
`graphrag index --method standard`. This is the default method, so the method param can actual be omitted.
|
||||
@ -23,7 +23,7 @@ FastGraphRAG is a method that substitutes some of the language model reasoning f
|
||||
- relationship extraction: relationships are defined as text unit co-occurrence between entity pairs. There is no description.
|
||||
- entity summarization: not necessary.
|
||||
- relationship summarization: not necessary.
|
||||
- claim extraction (optiona): unused.
|
||||
- claim extraction (optional): unused.
|
||||
- community report generation: The direct text unit content containing each entity noun phrase is collected and used to prompt the LLM to generate a summary report.
|
||||
|
||||
`graphrag index --method fast`
|
||||
@ -41,4 +41,4 @@ You can install it manually by running `python -m spacy download <model_name>`,
|
||||
|
||||
## Choosing a Method
|
||||
|
||||
Standard GraphRAG provides a rich description of real-world entities and relationships, but is more expensive that FastGraphRAG. We estimate graph extraction to constitute roughly 75% of indexing cost. FastGraphRAG is therefore much cheaper, but the tradeoff is that the extracted graph is less directly relevant for use outside of GraphRAG, and the graph tends to be quite a bit noisier. If high fidelity entities and graph exploration are important to your use case, we recommend staying with traditional GraphRAG. If your use case is primarily aimed at summary questions using global search, FastGraphRAG is a reasonable and cheaper alternative.
|
||||
Standard GraphRAG provides a rich description of real-world entities and relationships, but is more expensive that FastGraphRAG. We estimate graph extraction to constitute roughly 75% of indexing cost. FastGraphRAG is therefore much cheaper, but the tradeoff is that the extracted graph is less directly relevant for use outside of GraphRAG, and the graph tends to be quite a bit noisier. If high fidelity entities and graph exploration are important to your use case, we recommend staying with traditional GraphRAG. If your use case is primarily aimed at summary questions using global search, FastGraphRAG provides high quality summarization at much less LLM cost.
|
||||
|
||||
@ -26,8 +26,7 @@ After you have a config file you can run the pipeline using the CLI or the Pytho
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Via Poetry
|
||||
poetry run poe index --root <data_root> # default config mode
|
||||
uv run poe index --root <data_root> # default config mode
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
@ -58,13 +58,13 @@ graphrag prompt-tune [--root ROOT] [--config CONFIG] [--domain DOMAIN] [--selec
|
||||
```bash
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --domain "environmental news" \
|
||||
--selection-method random --limit 10 --language English --max-tokens 2048 --chunk-size 256 --min-examples-required 3 \
|
||||
--no-entity-types --output /path/to/output
|
||||
--no-discover-entity-types --output /path/to/output
|
||||
```
|
||||
|
||||
or, with minimal configuration (suggested):
|
||||
|
||||
```bash
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-entity-types
|
||||
python -m graphrag prompt-tune --root /path/to/project --config /path/to/settings.yaml --no-discover-entity-types
|
||||
```
|
||||
|
||||
## Document Selection Methods
|
||||
@ -79,15 +79,7 @@ After that, it uses one of the following selection methods to pick a sample to w
|
||||
|
||||
## Modify Env Vars
|
||||
|
||||
After running auto tuning, you should modify the following environment variables (or config variables) to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
|
||||
|
||||
- `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` = "prompts/entity_extraction.txt"
|
||||
|
||||
- `GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE` = "prompts/community_report.txt"
|
||||
|
||||
- `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` = "prompts/summarize_descriptions.txt"
|
||||
|
||||
or in your yaml config file:
|
||||
After running auto tuning, you should modify the following config variables to pick up the new prompts on your index run. Note: Please make sure to update the correct path to the generated prompts, in this example we are using the default "prompts" path.
|
||||
|
||||
```yaml
|
||||
entity_extraction:
|
||||
|
||||
@ -10,7 +10,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain
|
||||
|
||||
### Entity/Relationship Extraction
|
||||
|
||||
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/entity_extraction.py)
|
||||
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_graph.py)
|
||||
|
||||
#### Tokens
|
||||
|
||||
@ -31,7 +31,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain
|
||||
|
||||
### Claim Extraction
|
||||
|
||||
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/claim_extraction.py)
|
||||
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/index/extract_claims.py)
|
||||
|
||||
#### Tokens
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ Below are the key parameters of the [DRIFTSearch class](https://github.com/micro
|
||||
- `llm`: OpenAI model object to be used for response generation
|
||||
- `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/drift_context.py) object to be used for preparing context data from community reports and query information
|
||||
- `config`: model to define the DRIFT Search hyperparameters. [DRIFT Config model](https://github.com/microsoft/graphrag/blob/main/graphrag/config/models/drift_search_config.py)
|
||||
- `token_encoder`: token encoder for tracking the budget for the algorithm.
|
||||
- `tokenizer`: token encoder for tracking the budget for the algorithm.
|
||||
- `query_state`: a state object as defined in [Query State](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/state.py) that allows to track execution of a DRIFT Search instance, alongside follow ups and [DRIFT actions](https://github.com/microsoft/graphrag/blob/main/graphrag/query/structured_search/drift_search/action.py).
|
||||
|
||||
## How to Use
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# API Notebooks
|
||||
|
||||
- [API Overview Notebook](../../examples_notebooks/api_overview.ipynb)
|
||||
- [Bring-Your-Own Vector Store](../../examples_notebooks/custom_vector_store.ipynb)
|
||||
|
||||
# Query Engine Notebooks
|
||||
|
||||
|
||||
@ -29,6 +29,9 @@
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"\n",
|
||||
"from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
@ -38,9 +41,6 @@
|
||||
" read_indexer_reports,\n",
|
||||
" read_indexer_text_units,\n",
|
||||
")\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.embedding import OpenAIEmbedding\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"from graphrag.query.structured_search.local_search.mixed_context import (\n",
|
||||
" LocalSearchMixedContext,\n",
|
||||
")\n",
|
||||
|
||||
@ -9,29 +9,32 @@ Backwards compatibility is not guaranteed at this time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphrag.callbacks.reporting import create_pipeline_reporter
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import IndexingMethod
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.run.run_pipeline import run_pipeline
|
||||
from graphrag.index.run.utils import create_callback_chain
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.index.typing.workflow import WorkflowFunction
|
||||
from graphrag.index.workflows.factory import PipelineFactory
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.null_progress import NullProgressLogger
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def build_index(
|
||||
config: GraphRagConfig,
|
||||
method: IndexingMethod = IndexingMethod.Standard,
|
||||
method: IndexingMethod | str = IndexingMethod.Standard,
|
||||
is_update_run: bool = False,
|
||||
memory_profile: bool = False,
|
||||
callbacks: list[WorkflowCallbacks] | None = None,
|
||||
progress_logger: ProgressLogger | None = None,
|
||||
additional_context: dict[str, Any] | None = None,
|
||||
verbose: bool = False,
|
||||
input_documents: pd.DataFrame | None = None,
|
||||
) -> list[PipelineRunResult]:
|
||||
"""Run the pipeline with the given configuration.
|
||||
|
||||
@ -45,27 +48,32 @@ async def build_index(
|
||||
Whether to enable memory profiling.
|
||||
callbacks : list[WorkflowCallbacks] | None default=None
|
||||
A list of callbacks to register.
|
||||
progress_logger : ProgressLogger | None default=None
|
||||
The progress logger.
|
||||
additional_context : dict[str, Any] | None default=None
|
||||
Additional context to pass to the pipeline run. This can be accessed in the pipeline state under the 'additional_context' key.
|
||||
input_documents : pd.DataFrame | None default=None.
|
||||
Override document loading and parsing and supply your own dataframe of documents to index.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[PipelineRunResult]
|
||||
The list of pipeline run results
|
||||
"""
|
||||
logger = progress_logger or NullProgressLogger()
|
||||
# create a pipeline reporter and add to any additional callbacks
|
||||
callbacks = callbacks or []
|
||||
callbacks.append(create_pipeline_reporter(config.reporting, None))
|
||||
init_loggers(config=config, verbose=verbose)
|
||||
|
||||
workflow_callbacks = create_callback_chain(callbacks, logger)
|
||||
# Create callbacks for pipeline lifecycle events if provided
|
||||
workflow_callbacks = (
|
||||
create_callback_chain(callbacks) if callbacks else NoopWorkflowCallbacks()
|
||||
)
|
||||
|
||||
outputs: list[PipelineRunResult] = []
|
||||
|
||||
if memory_profile:
|
||||
log.warning("New pipeline does not yet support memory profiling.")
|
||||
logger.warning("New pipeline does not yet support memory profiling.")
|
||||
|
||||
pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
|
||||
logger.info("Initializing indexing pipeline...")
|
||||
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
|
||||
method = _get_method(method, is_update_run)
|
||||
pipeline = PipelineFactory.create_pipeline(config, method)
|
||||
|
||||
workflow_callbacks.pipeline_start(pipeline.names())
|
||||
|
||||
@ -73,20 +81,21 @@ async def build_index(
|
||||
pipeline,
|
||||
config,
|
||||
callbacks=workflow_callbacks,
|
||||
logger=logger,
|
||||
is_update_run=is_update_run,
|
||||
additional_context=additional_context,
|
||||
input_documents=input_documents,
|
||||
):
|
||||
outputs.append(output)
|
||||
if output.errors and len(output.errors) > 0:
|
||||
logger.error(output.workflow)
|
||||
logger.error("Workflow %s completed with errors", output.workflow)
|
||||
else:
|
||||
logger.success(output.workflow)
|
||||
logger.info(str(output.result))
|
||||
logger.info("Workflow %s completed successfully", output.workflow)
|
||||
logger.debug(str(output.result))
|
||||
|
||||
workflow_callbacks.pipeline_end(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
def register_workflow_function(name: str, workflow: WorkflowFunction):
|
||||
"""Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
|
||||
PipelineFactory.register(name, workflow)
|
||||
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
|
||||
m = method.value if isinstance(method, IndexingMethod) else method
|
||||
return f"{m}-update" if is_update_run else m
|
||||
|
||||
@ -11,6 +11,7 @@ WARNING: This API is under development and may undergo changes in future release
|
||||
Backwards compatibility is not guaranteed at this time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import annotated_types
|
||||
@ -20,7 +21,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
from graphrag.prompt_tune.defaults import MAX_TOKEN_COUNT, PROMPT_TUNING_MODEL_ID
|
||||
from graphrag.prompt_tune.generator.community_report_rating import (
|
||||
generate_community_report_rating,
|
||||
@ -46,13 +47,14 @@ from graphrag.prompt_tune.generator.language import detect_language
|
||||
from graphrag.prompt_tune.generator.persona import generate_persona
|
||||
from graphrag.prompt_tune.loader.input import load_docs_in_chunks
|
||||
from graphrag.prompt_tune.types import DocSelectionType
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
async def generate_indexing_prompts(
|
||||
config: GraphRagConfig,
|
||||
logger: ProgressLogger,
|
||||
root: str,
|
||||
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
|
||||
overlap: Annotated[
|
||||
int, annotated_types.Gt(-1)
|
||||
@ -66,14 +68,13 @@ async def generate_indexing_prompts(
|
||||
min_examples_required: PositiveInt = 2,
|
||||
n_subset_max: PositiveInt = 300,
|
||||
k: PositiveInt = 15,
|
||||
verbose: bool = False,
|
||||
) -> tuple[str, str, str]:
|
||||
"""Generate indexing prompts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- config: The GraphRag configuration.
|
||||
- logger: The logger to use for progress updates.
|
||||
- root: The root directory.
|
||||
- output_path: The path to store the prompts.
|
||||
- chunk_size: The chunk token size to use for input text units.
|
||||
- limit: The limit of chunks to load.
|
||||
@ -90,10 +91,11 @@ async def generate_indexing_prompts(
|
||||
-------
|
||||
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="prompt-tuning.log")
|
||||
|
||||
# Retrieve documents
|
||||
logger.info("Chunking documents...")
|
||||
doc_list = await load_docs_in_chunks(
|
||||
root=root,
|
||||
config=config,
|
||||
limit=limit,
|
||||
select_method=selection_method,
|
||||
@ -165,7 +167,7 @@ async def generate_indexing_prompts(
|
||||
examples=examples,
|
||||
language=language,
|
||||
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json mode by the index engine
|
||||
encoding_model=extract_graph_llm_settings.encoding_model,
|
||||
tokenizer=get_tokenizer(model_config=extract_graph_llm_settings),
|
||||
max_token_count=max_tokens,
|
||||
min_examples_required=min_examples_required,
|
||||
)
|
||||
@ -189,9 +191,9 @@ async def generate_indexing_prompts(
|
||||
language=language,
|
||||
)
|
||||
|
||||
logger.info(f"\nGenerated domain: {domain}") # noqa: G004
|
||||
logger.info(f"\nDetected language: {language}") # noqa: G004
|
||||
logger.info(f"\nGenerated persona: {persona}") # noqa: G004
|
||||
logger.debug("Generated domain: %s", domain)
|
||||
logger.debug("Detected language: %s", language)
|
||||
logger.debug("Generated persona: %s", persona)
|
||||
|
||||
return (
|
||||
extract_graph_prompt,
|
||||
|
||||
@ -17,6 +17,7 @@ WARNING: This API is under development and may undergo changes in future release
|
||||
Backwards compatibility is not guaranteed at this time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
@ -31,7 +32,7 @@ from graphrag.config.embeddings import (
|
||||
text_unit_text_embedding,
|
||||
)
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
from graphrag.query.factory import (
|
||||
get_basic_search_engine,
|
||||
get_drift_search_engine,
|
||||
@ -50,11 +51,13 @@ from graphrag.query.indexer_adapters import (
|
||||
from graphrag.utils.api import (
|
||||
get_embedding_store,
|
||||
load_search_prompt,
|
||||
truncate,
|
||||
update_context_data,
|
||||
)
|
||||
from graphrag.utils.cli import redact
|
||||
|
||||
logger = PrintProgressLogger("")
|
||||
# Initialize standard logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@validate_call(config={"arbitrary_types_allowed": True})
|
||||
@ -68,6 +71,7 @@ async def global_search(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -88,11 +92,9 @@ async def global_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
context_data = {}
|
||||
@ -105,6 +107,7 @@ async def global_search(
|
||||
local_callbacks.on_context = on_context
|
||||
callbacks.append(local_callbacks)
|
||||
|
||||
logger.debug("Executing global search query: %s", query)
|
||||
async for chunk in global_search_streaming(
|
||||
config=config,
|
||||
entities=entities,
|
||||
@ -117,6 +120,7 @@ async def global_search(
|
||||
callbacks=callbacks,
|
||||
):
|
||||
full_response += chunk
|
||||
logger.debug("Query response: %s", truncate(full_response, 400))
|
||||
return full_response, context_data
|
||||
|
||||
|
||||
@ -131,6 +135,7 @@ def global_search_streaming(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a global search and return the context data and response via a generator.
|
||||
|
||||
@ -150,11 +155,9 @@ def global_search_streaming(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
communities_ = read_indexer_communities(communities, community_reports)
|
||||
reports = read_indexer_reports(
|
||||
community_reports,
|
||||
@ -173,6 +176,7 @@ def global_search_streaming(
|
||||
config.root_dir, config.global_search.knowledge_prompt
|
||||
)
|
||||
|
||||
logger.debug("Executing streaming global search query: %s", query)
|
||||
search_engine = get_global_search_engine(
|
||||
config,
|
||||
reports=reports,
|
||||
@ -201,6 +205,7 @@ async def multi_index_global_search(
|
||||
streaming: bool,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -223,11 +228,13 @@ async def multi_index_global_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_global_search"
|
||||
@ -311,6 +318,7 @@ async def multi_index_global_search(
|
||||
communities_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
|
||||
logger.debug("Executing multi-index global search query: %s", query)
|
||||
result = await global_search(
|
||||
config,
|
||||
entities=entities_combined,
|
||||
@ -326,6 +334,7 @@ async def multi_index_global_search(
|
||||
# Update the context data by linking index names and community ids
|
||||
context = update_context_data(result[1], links)
|
||||
|
||||
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
||||
return (result[0], context)
|
||||
|
||||
|
||||
@ -342,6 +351,7 @@ async def local_search(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -362,11 +372,9 @@ async def local_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
context_data = {}
|
||||
@ -379,6 +387,7 @@ async def local_search(
|
||||
local_callbacks.on_context = on_context
|
||||
callbacks.append(local_callbacks)
|
||||
|
||||
logger.debug("Executing local search query: %s", query)
|
||||
async for chunk in local_search_streaming(
|
||||
config=config,
|
||||
entities=entities,
|
||||
@ -393,6 +402,7 @@ async def local_search(
|
||||
callbacks=callbacks,
|
||||
):
|
||||
full_response += chunk
|
||||
logger.debug("Query response: %s", truncate(full_response, 400))
|
||||
return full_response, context_data
|
||||
|
||||
|
||||
@ -409,6 +419,7 @@ def local_search_streaming(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a local search and return the context data and response via a generator.
|
||||
|
||||
@ -427,16 +438,14 @@ def local_search_streaming(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
vector_store_args[index] = store.model_dump()
|
||||
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
||||
logger.info(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
description_embedding_store = get_embedding_store(
|
||||
config_args=vector_store_args,
|
||||
@ -447,6 +456,7 @@ def local_search_streaming(
|
||||
covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
|
||||
prompt = load_search_prompt(config.root_dir, config.local_search.prompt)
|
||||
|
||||
logger.debug("Executing streaming local search query: %s", query)
|
||||
search_engine = get_local_search_engine(
|
||||
config=config,
|
||||
reports=read_indexer_reports(community_reports, communities, community_level),
|
||||
@ -477,6 +487,7 @@ async def multi_index_local_search(
|
||||
streaming: bool,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -500,11 +511,12 @@ async def multi_index_local_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_index_local_search"
|
||||
@ -670,6 +682,7 @@ async def multi_index_local_search(
|
||||
covariates_combined = pd.concat(
|
||||
covariates_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
logger.debug("Executing multi-index local search query: %s", query)
|
||||
result = await local_search(
|
||||
config,
|
||||
entities=entities_combined,
|
||||
@ -687,6 +700,7 @@ async def multi_index_local_search(
|
||||
# Update the context data by linking index names and community ids
|
||||
context = update_context_data(result[1], links)
|
||||
|
||||
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
||||
return (result[0], context)
|
||||
|
||||
|
||||
@ -702,6 +716,7 @@ async def drift_search(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -721,11 +736,9 @@ async def drift_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
context_data = {}
|
||||
@ -738,6 +751,7 @@ async def drift_search(
|
||||
local_callbacks.on_context = on_context
|
||||
callbacks.append(local_callbacks)
|
||||
|
||||
logger.debug("Executing drift search query: %s", query)
|
||||
async for chunk in drift_search_streaming(
|
||||
config=config,
|
||||
entities=entities,
|
||||
@ -751,6 +765,7 @@ async def drift_search(
|
||||
callbacks=callbacks,
|
||||
):
|
||||
full_response += chunk
|
||||
logger.debug("Query response: %s", truncate(full_response, 400))
|
||||
return full_response, context_data
|
||||
|
||||
|
||||
@ -766,6 +781,7 @@ def drift_search_streaming(
|
||||
response_type: str,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a DRIFT search and return the context data and response.
|
||||
|
||||
@ -782,16 +798,14 @@ def drift_search_streaming(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
vector_store_args[index] = store.model_dump()
|
||||
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
||||
logger.info(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
description_embedding_store = get_embedding_store(
|
||||
config_args=vector_store_args,
|
||||
@ -811,6 +825,7 @@ def drift_search_streaming(
|
||||
config.root_dir, config.drift_search.reduce_prompt
|
||||
)
|
||||
|
||||
logger.debug("Executing streaming drift search query: %s", query)
|
||||
search_engine = get_drift_search_engine(
|
||||
config=config,
|
||||
reports=reports,
|
||||
@ -840,6 +855,7 @@ async def multi_index_drift_search(
|
||||
streaming: bool,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -862,11 +878,13 @@ async def multi_index_drift_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_drift_search"
|
||||
@ -1009,6 +1027,7 @@ async def multi_index_drift_search(
|
||||
text_units_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
|
||||
logger.debug("Executing multi-index drift search query: %s", query)
|
||||
result = await drift_search(
|
||||
config,
|
||||
entities=entities_combined,
|
||||
@ -1029,6 +1048,8 @@ async def multi_index_drift_search(
|
||||
context[key] = update_context_data(result[1][key], links)
|
||||
else:
|
||||
context = result[1]
|
||||
|
||||
logger.debug("Query response: %s", truncate(result[0], 400)) # type: ignore
|
||||
return (result[0], context)
|
||||
|
||||
|
||||
@ -1038,6 +1059,7 @@ async def basic_search(
|
||||
text_units: pd.DataFrame,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -1053,11 +1075,9 @@ async def basic_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
callbacks = callbacks or []
|
||||
full_response = ""
|
||||
context_data = {}
|
||||
@ -1070,6 +1090,7 @@ async def basic_search(
|
||||
local_callbacks.on_context = on_context
|
||||
callbacks.append(local_callbacks)
|
||||
|
||||
logger.debug("Executing basic search query: %s", query)
|
||||
async for chunk in basic_search_streaming(
|
||||
config=config,
|
||||
text_units=text_units,
|
||||
@ -1077,6 +1098,7 @@ async def basic_search(
|
||||
callbacks=callbacks,
|
||||
):
|
||||
full_response += chunk
|
||||
logger.debug("Query response: %s", truncate(full_response, 400))
|
||||
return full_response, context_data
|
||||
|
||||
|
||||
@ -1086,6 +1108,7 @@ def basic_search_streaming(
|
||||
text_units: pd.DataFrame,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> AsyncGenerator:
|
||||
"""Perform a local search and return the context data and response via a generator.
|
||||
|
||||
@ -1098,28 +1121,27 @@ def basic_search_streaming(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
vector_store_args = {}
|
||||
for index, store in config.vector_store.items():
|
||||
vector_store_args[index] = store.model_dump()
|
||||
msg = f"Vector Store Args: {redact(vector_store_args)}"
|
||||
logger.info(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
description_embedding_store = get_embedding_store(
|
||||
embedding_store = get_embedding_store(
|
||||
config_args=vector_store_args,
|
||||
embedding_name=text_unit_text_embedding,
|
||||
)
|
||||
|
||||
prompt = load_search_prompt(config.root_dir, config.basic_search.prompt)
|
||||
|
||||
logger.debug("Executing streaming basic search query: %s", query)
|
||||
search_engine = get_basic_search_engine(
|
||||
config=config,
|
||||
text_units=read_indexer_text_units(text_units),
|
||||
text_unit_embeddings=description_embedding_store,
|
||||
text_unit_embeddings=embedding_store,
|
||||
system_prompt=prompt,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
@ -1134,6 +1156,7 @@ async def multi_index_basic_search(
|
||||
streaming: bool,
|
||||
query: str,
|
||||
callbacks: list[QueryCallbacks] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> tuple[
|
||||
str | dict[str, Any] | list[dict[str, Any]],
|
||||
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
|
||||
@ -1151,11 +1174,13 @@ async def multi_index_basic_search(
|
||||
Returns
|
||||
-------
|
||||
TODO: Document the search response type and format.
|
||||
|
||||
Raises
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
init_loggers(config=config, verbose=verbose, filename="query.log")
|
||||
|
||||
logger.warning(
|
||||
"Multi-index search is deprecated and will be removed in GraphRAG v3."
|
||||
)
|
||||
|
||||
# Streaming not supported yet
|
||||
if streaming:
|
||||
message = "Streaming not yet implemented for multi_basic_search"
|
||||
@ -1192,6 +1217,7 @@ async def multi_index_basic_search(
|
||||
text_units_dfs, axis=0, ignore_index=True, sort=False
|
||||
)
|
||||
|
||||
logger.debug("Executing multi-index basic search query: %s", query)
|
||||
return await basic_search(
|
||||
config,
|
||||
text_units=text_units_combined,
|
||||
|
||||
128
graphrag/cache/factory.py
vendored
128
graphrag/cache/factory.py
vendored
@ -1,23 +1,24 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing create_cache method definition."""
|
||||
"""Factory functions for creating a cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.storage.blob_pipeline_storage import create_blob_storage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
|
||||
from graphrag.cache.json_pipeline_cache import JsonPipelineCache
|
||||
from graphrag.cache.memory_pipeline_cache import InMemoryCache
|
||||
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
|
||||
from graphrag.config.enums import CacheType
|
||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
|
||||
|
||||
class CacheFactory:
|
||||
@ -25,39 +26,90 @@ class CacheFactory:
|
||||
|
||||
Includes a method for users to register a custom cache implementation.
|
||||
|
||||
Configuration arguments are passed to each cache implementation as kwargs (where possible)
|
||||
Configuration arguments are passed to each cache implementation as kwargs
|
||||
for individual enforcement of required/optional arguments.
|
||||
"""
|
||||
|
||||
cache_types: ClassVar[dict[str, type]] = {}
|
||||
_registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, cache_type: str, cache: type):
|
||||
"""Register a custom cache implementation."""
|
||||
cls.cache_types[cache_type] = cache
|
||||
def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
|
||||
"""Register a custom cache implementation.
|
||||
|
||||
Args:
|
||||
cache_type: The type identifier for the cache.
|
||||
creator: A class or callable that creates an instance of PipelineCache.
|
||||
"""
|
||||
cls._registry[cache_type] = creator
|
||||
|
||||
@classmethod
|
||||
def create_cache(
|
||||
cls, cache_type: CacheType | str | None, root_dir: str, kwargs: dict
|
||||
) -> PipelineCache:
|
||||
"""Create or get a cache from the provided type."""
|
||||
if not cache_type:
|
||||
return NoopPipelineCache()
|
||||
match cache_type:
|
||||
case CacheType.none:
|
||||
return NoopPipelineCache()
|
||||
case CacheType.memory:
|
||||
return InMemoryCache()
|
||||
case CacheType.file:
|
||||
return JsonPipelineCache(
|
||||
FilePipelineStorage(root_dir=root_dir).child(kwargs["base_dir"])
|
||||
)
|
||||
case CacheType.blob:
|
||||
return JsonPipelineCache(create_blob_storage(**kwargs))
|
||||
case CacheType.cosmosdb:
|
||||
return JsonPipelineCache(create_cosmosdb_storage(**kwargs))
|
||||
case _:
|
||||
if cache_type in cls.cache_types:
|
||||
return cls.cache_types[cache_type](**kwargs)
|
||||
msg = f"Unknown cache type: {cache_type}"
|
||||
raise ValueError(msg)
|
||||
def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
|
||||
"""Create a cache object from the provided type.
|
||||
|
||||
Args:
|
||||
cache_type: The type of cache to create.
|
||||
root_dir: The root directory for file-based caches.
|
||||
kwargs: Additional keyword arguments for the cache constructor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A PipelineCache instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the cache type is not registered.
|
||||
"""
|
||||
if cache_type not in cls._registry:
|
||||
msg = f"Unknown cache type: {cache_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls._registry[cache_type](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_cache_types(cls) -> list[str]:
|
||||
"""Get the registered cache implementations."""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported_type(cls, cache_type: str) -> bool:
|
||||
"""Check if the given cache type is supported."""
|
||||
return cache_type in cls._registry
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
|
||||
"""Create a file-based cache implementation."""
|
||||
# Create storage with base_dir in kwargs since FilePipelineStorage expects it there
|
||||
storage_kwargs = {"base_dir": root_dir, **kwargs}
|
||||
storage = FilePipelineStorage(**storage_kwargs).child(base_dir)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_blob_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a blob storage-based cache implementation."""
|
||||
storage = BlobPipelineStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_cosmosdb_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a CosmosDB-based cache implementation."""
|
||||
storage = CosmosDBPipelineStorage(**kwargs)
|
||||
return JsonPipelineCache(storage)
|
||||
|
||||
|
||||
def create_noop_cache(**_kwargs) -> PipelineCache:
|
||||
"""Create a no-op cache implementation."""
|
||||
return NoopPipelineCache()
|
||||
|
||||
|
||||
def create_memory_cache(**kwargs) -> PipelineCache:
|
||||
"""Create a memory cache implementation."""
|
||||
return InMemoryCache(**kwargs)
|
||||
|
||||
|
||||
# --- register built-in cache implementations ---
|
||||
CacheFactory.register(CacheType.none.value, create_noop_cache)
|
||||
CacheFactory.register(CacheType.memory.value, create_memory_cache)
|
||||
CacheFactory.register(CacheType.file.value, create_file_cache)
|
||||
CacheFactory.register(CacheType.blob.value, create_blob_cache)
|
||||
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
|
||||
|
||||
@ -4,29 +4,43 @@
|
||||
"""A logger that emits updates from the indexing engine to the console."""
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
# ruff: noqa: T201
|
||||
|
||||
|
||||
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A logger that writes to a console."""
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
):
|
||||
"""Handle when an error occurs."""
|
||||
print(message, str(cause), stack, details) # noqa T201
|
||||
_verbose = False
|
||||
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Handle when a warning occurs."""
|
||||
_print_warning(message)
|
||||
def __init__(self, verbose=False):
|
||||
self._verbose = verbose
|
||||
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Handle when a log message is produced."""
|
||||
print(message, details) # noqa T201
|
||||
def pipeline_start(self, names: list[str]) -> None:
|
||||
"""Execute this callback to signal when the entire pipeline starts."""
|
||||
print("Starting pipeline with workflows:", ", ".join(names))
|
||||
|
||||
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
|
||||
"""Execute this callback to signal when the entire pipeline ends."""
|
||||
print("Pipeline complete")
|
||||
|
||||
def _print_warning(skk):
|
||||
print("\033[93m {}\033[00m".format(skk)) # noqa T201
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
print(f"Starting workflow: {name}")
|
||||
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
print("") # account for potential return on prior progress
|
||||
print(f"Workflow complete: {name}")
|
||||
if self._verbose:
|
||||
print(instance)
|
||||
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
complete = progress.completed_items or 0
|
||||
total = progress.total_items or 1
|
||||
percent = round((complete / total) * 100)
|
||||
start = f" {complete} / {total} "
|
||||
print(f"{start:{'.'}<{percent}}", flush=True, end="\r")
|
||||
|
||||
@ -1,78 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A logger that emits updates from the indexing engine to a local file."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from io import TextIOWrapper
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A logger that writes to a local file."""
|
||||
|
||||
_out_stream: TextIOWrapper
|
||||
|
||||
def __init__(self, directory: str):
|
||||
"""Create a new file-based workflow logger."""
|
||||
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
self._out_stream = open( # noqa: PTH123, SIM115
|
||||
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
|
||||
)
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
):
|
||||
"""Handle when an error occurs."""
|
||||
self._out_stream.write(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"data": message,
|
||||
"stack": stack,
|
||||
"source": str(cause),
|
||||
"details": details,
|
||||
},
|
||||
indent=4,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
message = f"{message} details={details}"
|
||||
log.info(message)
|
||||
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Handle when a warning occurs."""
|
||||
self._out_stream.write(
|
||||
json.dumps(
|
||||
{"type": "warning", "data": message, "details": details},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
_print_warning(message)
|
||||
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Handle when a log message is produced."""
|
||||
self._out_stream.write(
|
||||
json.dumps(
|
||||
{"type": "log", "data": message, "details": details}, ensure_ascii=False
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
message = f"{message} details={details}"
|
||||
log.info(message)
|
||||
|
||||
|
||||
def _print_warning(skk):
|
||||
log.warning(skk)
|
||||
@ -9,13 +9,13 @@ from graphrag.logger.progress import Progress
|
||||
|
||||
|
||||
class NoopWorkflowCallbacks(WorkflowCallbacks):
|
||||
"""A no-op implementation of WorkflowCallbacks."""
|
||||
"""A no-op implementation of WorkflowCallbacks that logs all events to standard logging."""
|
||||
|
||||
def pipeline_start(self, names: list[str]) -> None:
|
||||
"""Execute this callback when a the entire pipeline starts."""
|
||||
"""Execute this callback to signal when the entire pipeline starts."""
|
||||
|
||||
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
|
||||
"""Execute this callback when the entire pipeline ends."""
|
||||
"""Execute this callback to signal when the entire pipeline ends."""
|
||||
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
@ -25,18 +25,3 @@ class NoopWorkflowCallbacks(WorkflowCallbacks):
|
||||
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
) -> None:
|
||||
"""Handle when an error occurs."""
|
||||
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A workflow callback manager that emits updates."""
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
|
||||
class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A callbackmanager that delegates to a ProgressLogger."""
|
||||
|
||||
_root_progress: ProgressLogger
|
||||
_progress_stack: list[ProgressLogger]
|
||||
|
||||
def __init__(self, progress: ProgressLogger) -> None:
|
||||
"""Create a new ProgressWorkflowCallbacks."""
|
||||
self._progress = progress
|
||||
self._progress_stack = [progress]
|
||||
|
||||
def _pop(self) -> None:
|
||||
self._progress_stack.pop()
|
||||
|
||||
def _push(self, name: str) -> None:
|
||||
self._progress_stack.append(self._latest.child(name))
|
||||
|
||||
@property
|
||||
def _latest(self) -> ProgressLogger:
|
||||
return self._progress_stack[-1]
|
||||
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
self._push(name)
|
||||
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
self._pop()
|
||||
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
self._latest(progress)
|
||||
@ -1,39 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""A module containing the pipeline reporter factory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
|
||||
from graphrag.config.enums import ReportingType
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
|
||||
def create_pipeline_reporter(
|
||||
config: ReportingConfig | None, root_dir: str | None
|
||||
) -> WorkflowCallbacks:
|
||||
"""Create a logger for the given pipeline config."""
|
||||
config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
|
||||
match config.type:
|
||||
case ReportingType.file:
|
||||
return FileWorkflowCallbacks(
|
||||
str(Path(root_dir or "") / (config.base_dir or ""))
|
||||
)
|
||||
case ReportingType.console:
|
||||
return ConsoleWorkflowCallbacks()
|
||||
case ReportingType.blob:
|
||||
return BlobWorkflowCallbacks(
|
||||
config.connection_string,
|
||||
config.container_name,
|
||||
base_dir=config.base_dir,
|
||||
storage_account_blob_url=config.storage_account_blob_url,
|
||||
)
|
||||
@ -35,21 +35,3 @@ class WorkflowCallbacks(Protocol):
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
...
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
) -> None:
|
||||
"""Handle when an error occurs."""
|
||||
...
|
||||
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
...
|
||||
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
...
|
||||
|
||||
@ -50,27 +50,3 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "progress"):
|
||||
callback.progress(progress)
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
details: dict | None = None,
|
||||
) -> None:
|
||||
"""Handle when an error occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "error"):
|
||||
callback.error(message, cause, stack, details)
|
||||
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "warning"):
|
||||
callback.warning(message, details)
|
||||
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "log"):
|
||||
callback.log(message, details)
|
||||
|
||||
@ -10,49 +10,27 @@ import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.api as api
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.config.enums import CacheType, IndexingMethod
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.logging import enable_logging_with_config
|
||||
from graphrag.index.validate_config import validate_config_names
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.factory import LoggerFactory, LoggerType
|
||||
from graphrag.utils.cli import redact
|
||||
|
||||
# Ignore warnings from numba
|
||||
warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _logger(logger: ProgressLogger):
|
||||
def info(msg: str, verbose: bool = False):
|
||||
log.info(msg)
|
||||
if verbose:
|
||||
logger.info(msg)
|
||||
|
||||
def error(msg: str, verbose: bool = False):
|
||||
log.error(msg)
|
||||
if verbose:
|
||||
logger.error(msg)
|
||||
|
||||
def success(msg: str, verbose: bool = False):
|
||||
log.info(msg)
|
||||
if verbose:
|
||||
logger.success(msg)
|
||||
|
||||
return info, error, success
|
||||
|
||||
|
||||
def _register_signal_handlers(logger: ProgressLogger):
|
||||
def _register_signal_handlers():
|
||||
import signal
|
||||
|
||||
def handle_signal(signum, _):
|
||||
# Handle the signal here
|
||||
logger.info(f"Received signal {signum}, exiting...") # noqa: G004
|
||||
logger.dispose()
|
||||
logger.debug(f"Received signal {signum}, exiting...") # noqa: G004
|
||||
for task in asyncio.all_tasks():
|
||||
task.cancel()
|
||||
logger.info("All tasks cancelled. Exiting...")
|
||||
logger.debug("All tasks cancelled. Exiting...")
|
||||
|
||||
# Register signal handlers for SIGINT and SIGHUP
|
||||
signal.signal(signal.SIGINT, handle_signal)
|
||||
@ -67,7 +45,6 @@ def index_cli(
|
||||
verbose: bool,
|
||||
memprofile: bool,
|
||||
cache: bool,
|
||||
logger: LoggerType,
|
||||
config_filepath: Path | None,
|
||||
dry_run: bool,
|
||||
skip_validation: bool,
|
||||
@ -80,7 +57,6 @@ def index_cli(
|
||||
cli_overrides["reporting.base_dir"] = str(output_dir)
|
||||
cli_overrides["update_index_output.base_dir"] = str(output_dir)
|
||||
config = load_config(root_dir, config_filepath, cli_overrides)
|
||||
|
||||
_run_index(
|
||||
config=config,
|
||||
method=method,
|
||||
@ -88,7 +64,6 @@ def index_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=dry_run,
|
||||
skip_validation=skip_validation,
|
||||
)
|
||||
@ -100,7 +75,6 @@ def update_cli(
|
||||
verbose: bool,
|
||||
memprofile: bool,
|
||||
cache: bool,
|
||||
logger: LoggerType,
|
||||
config_filepath: Path | None,
|
||||
skip_validation: bool,
|
||||
output_dir: Path | None,
|
||||
@ -121,7 +95,6 @@ def update_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=logger,
|
||||
dry_run=False,
|
||||
skip_validation=skip_validation,
|
||||
)
|
||||
@ -134,39 +107,35 @@ def _run_index(
|
||||
verbose,
|
||||
memprofile,
|
||||
cache,
|
||||
logger,
|
||||
dry_run,
|
||||
skip_validation,
|
||||
):
|
||||
progress_logger = LoggerFactory().create_logger(logger)
|
||||
info, error, success = _logger(progress_logger)
|
||||
# Configure the root logger with the specified log level
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
# Initialize loggers and reporting config
|
||||
init_loggers(
|
||||
config=config,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if not cache:
|
||||
config.cache.type = CacheType.none
|
||||
|
||||
enabled_logging, log_path = enable_logging_with_config(config, verbose)
|
||||
if enabled_logging:
|
||||
info(f"Logging enabled at {log_path}", True)
|
||||
else:
|
||||
info(
|
||||
f"Logging not enabled for config {redact(config.model_dump())}",
|
||||
True,
|
||||
)
|
||||
|
||||
if not skip_validation:
|
||||
validate_config_names(progress_logger, config)
|
||||
validate_config_names(config)
|
||||
|
||||
info(f"Starting pipeline run. {dry_run=}", verbose)
|
||||
info(
|
||||
f"Using default configuration: {redact(config.model_dump())}",
|
||||
verbose,
|
||||
logger.info("Starting pipeline run. %s", dry_run)
|
||||
logger.info(
|
||||
"Using default configuration: %s",
|
||||
redact(config.model_dump()),
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
info("Dry run complete, exiting...", True)
|
||||
logger.info("Dry run complete, exiting...", True)
|
||||
sys.exit(0)
|
||||
|
||||
_register_signal_handlers(progress_logger)
|
||||
_register_signal_handlers()
|
||||
|
||||
outputs = asyncio.run(
|
||||
api.build_index(
|
||||
@ -174,19 +143,19 @@ def _run_index(
|
||||
method=method,
|
||||
is_update_run=is_update_run,
|
||||
memory_profile=memprofile,
|
||||
progress_logger=progress_logger,
|
||||
callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
encountered_errors = any(
|
||||
output.errors and len(output.errors) > 0 for output in outputs
|
||||
)
|
||||
|
||||
progress_logger.stop()
|
||||
if encountered_errors:
|
||||
error(
|
||||
"Errors occurred during the pipeline run, see logs for more details.", True
|
||||
logger.error(
|
||||
"Errors occurred during the pipeline run, see logs for more details."
|
||||
)
|
||||
else:
|
||||
success("All workflows completed successfully.", True)
|
||||
logger.info("All workflows completed successfully.")
|
||||
|
||||
sys.exit(1 if encountered_errors else 0)
|
||||
|
||||
@ -3,10 +3,10 @@
|
||||
|
||||
"""CLI implementation of the initialization subcommand."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag.config.init_content import INIT_DOTENV, INIT_YAML
|
||||
from graphrag.logger.factory import LoggerFactory, LoggerType
|
||||
from graphrag.prompts.index.community_report import (
|
||||
COMMUNITY_REPORT_PROMPT,
|
||||
)
|
||||
@ -31,6 +31,8 @@ from graphrag.prompts.query.global_search_reduce_system_prompt import (
|
||||
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
|
||||
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_project_at(path: Path, force: bool) -> None:
|
||||
"""
|
||||
@ -48,8 +50,7 @@ def initialize_project_at(path: Path, force: bool) -> None:
|
||||
ValueError
|
||||
If the project already exists and force is False.
|
||||
"""
|
||||
progress_logger = LoggerFactory().create_logger(LoggerType.RICH)
|
||||
progress_logger.info(f"Initializing project at {path}") # noqa: G004
|
||||
logger.info("Initializing project at %s", path)
|
||||
root = Path(path)
|
||||
if not root.exists():
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -12,7 +12,6 @@ import typer
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import IndexingMethod, SearchMethod
|
||||
from graphrag.logger.types import LoggerType
|
||||
from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
|
||||
from graphrag.prompt_tune.types import DocSelectionType
|
||||
|
||||
@ -157,11 +156,6 @@ def _index_cli(
|
||||
"--memprofile",
|
||||
help="Run the indexing pipeline with memory profiling",
|
||||
),
|
||||
logger: LoggerType = typer.Option(
|
||||
LoggerType.RICH.value,
|
||||
"--logger",
|
||||
help="The progress logger to use.",
|
||||
),
|
||||
dry_run: bool = typer.Option(
|
||||
False,
|
||||
"--dry-run",
|
||||
@ -201,7 +195,6 @@ def _index_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=LoggerType(logger),
|
||||
config_filepath=config,
|
||||
dry_run=dry_run,
|
||||
skip_validation=skip_validation,
|
||||
@ -250,11 +243,6 @@ def _update_cli(
|
||||
"--memprofile",
|
||||
help="Run the indexing pipeline with memory profiling.",
|
||||
),
|
||||
logger: LoggerType = typer.Option(
|
||||
LoggerType.RICH.value,
|
||||
"--logger",
|
||||
help="The progress logger to use.",
|
||||
),
|
||||
cache: bool = typer.Option(
|
||||
True,
|
||||
"--cache/--no-cache",
|
||||
@ -290,7 +278,6 @@ def _update_cli(
|
||||
verbose=verbose,
|
||||
memprofile=memprofile,
|
||||
cache=cache,
|
||||
logger=LoggerType(logger),
|
||||
config_filepath=config,
|
||||
skip_validation=skip_validation,
|
||||
output_dir=output,
|
||||
@ -327,11 +314,6 @@ def _prompt_tune_cli(
|
||||
"-v",
|
||||
help="Run the prompt tuning pipeline with verbose logging.",
|
||||
),
|
||||
logger: LoggerType = typer.Option(
|
||||
LoggerType.RICH.value,
|
||||
"--logger",
|
||||
help="The progress logger to use.",
|
||||
),
|
||||
domain: str | None = typer.Option(
|
||||
None,
|
||||
"--domain",
|
||||
@ -413,7 +395,6 @@ def _prompt_tune_cli(
|
||||
config=config,
|
||||
domain=domain,
|
||||
verbose=verbose,
|
||||
logger=logger,
|
||||
selection_method=selection_method,
|
||||
limit=limit,
|
||||
max_tokens=max_tokens,
|
||||
@ -453,6 +434,12 @@ def _query_cli(
|
||||
readable=True,
|
||||
autocompletion=CONFIG_AUTOCOMPLETE,
|
||||
),
|
||||
verbose: bool = typer.Option(
|
||||
False,
|
||||
"--verbose",
|
||||
"-v",
|
||||
help="Run the query with verbose logging.",
|
||||
),
|
||||
data: Path | None = typer.Option(
|
||||
None,
|
||||
"--data",
|
||||
@ -520,6 +507,7 @@ def _query_cli(
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
case SearchMethod.GLOBAL:
|
||||
run_global_search(
|
||||
@ -531,6 +519,7 @@ def _query_cli(
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
case SearchMethod.DRIFT:
|
||||
run_drift_search(
|
||||
@ -541,6 +530,7 @@ def _query_cli(
|
||||
streaming=streaming,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
case SearchMethod.BASIC:
|
||||
run_basic_search(
|
||||
@ -549,6 +539,7 @@ def _query_cli(
|
||||
root_dir=root,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(INVALID_METHOD_ERROR)
|
||||
|
||||
@ -3,13 +3,11 @@
|
||||
|
||||
"""CLI implementation of the prompt-tune subcommand."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import graphrag.api as api
|
||||
from graphrag.cli.index import _logger
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.logging import enable_logging_with_config
|
||||
from graphrag.logger.factory import LoggerFactory, LoggerType
|
||||
from graphrag.prompt_tune.generator.community_report_summarization import (
|
||||
COMMUNITY_SUMMARIZATION_FILENAME,
|
||||
)
|
||||
@ -21,13 +19,14 @@ from graphrag.prompt_tune.generator.extract_graph_prompt import (
|
||||
)
|
||||
from graphrag.utils.cli import redact
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def prompt_tune(
|
||||
root: Path,
|
||||
config: Path | None,
|
||||
domain: str | None,
|
||||
verbose: bool,
|
||||
logger: LoggerType,
|
||||
selection_method: api.DocSelectionType,
|
||||
limit: int,
|
||||
max_tokens: int,
|
||||
@ -47,8 +46,7 @@ async def prompt_tune(
|
||||
- config: The configuration file.
|
||||
- root: The root directory.
|
||||
- domain: The domain to map the input documents to.
|
||||
- verbose: Whether to enable verbose logging.
|
||||
- logger: The logger to use.
|
||||
- verbose: Enable verbose logging.
|
||||
- selection_method: The chunk selection method.
|
||||
- limit: The limit of chunks to load.
|
||||
- max_tokens: The maximum number of tokens to use on entity extraction prompts.
|
||||
@ -70,24 +68,20 @@ async def prompt_tune(
|
||||
if overlap != graph_config.chunks.overlap:
|
||||
graph_config.chunks.overlap = overlap
|
||||
|
||||
progress_logger = LoggerFactory().create_logger(logger)
|
||||
info, error, success = _logger(progress_logger)
|
||||
# configure the root logger with the specified log level
|
||||
from graphrag.logger.standard_logging import init_loggers
|
||||
|
||||
enabled_logging, log_path = enable_logging_with_config(
|
||||
graph_config, verbose, filename="prompt-tune.log"
|
||||
# initialize loggers with config
|
||||
init_loggers(config=graph_config, verbose=verbose, filename="prompt-tuning.log")
|
||||
|
||||
logger.info("Starting prompt tune.")
|
||||
logger.info(
|
||||
"Using default configuration: %s",
|
||||
redact(graph_config.model_dump()),
|
||||
)
|
||||
if enabled_logging:
|
||||
info(f"Logging enabled at {log_path}", verbose)
|
||||
else:
|
||||
info(
|
||||
f"Logging not enabled for config {redact(graph_config.model_dump())}",
|
||||
verbose,
|
||||
)
|
||||
|
||||
prompts = await api.generate_indexing_prompts(
|
||||
config=graph_config,
|
||||
root=str(root_path),
|
||||
logger=progress_logger,
|
||||
chunk_size=chunk_size,
|
||||
overlap=overlap,
|
||||
limit=limit,
|
||||
@ -99,24 +93,25 @@ async def prompt_tune(
|
||||
min_examples_required=min_examples_required,
|
||||
n_subset_max=n_subset_max,
|
||||
k=k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
output_path = output.resolve()
|
||||
if output_path:
|
||||
info(f"Writing prompts to {output_path}")
|
||||
logger.info("Writing prompts to %s", output_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
extract_graph_prompt_path = output_path / EXTRACT_GRAPH_FILENAME
|
||||
entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
|
||||
community_summarization_prompt_path = (
|
||||
output_path / COMMUNITY_SUMMARIZATION_FILENAME
|
||||
)
|
||||
# Write files to output path
|
||||
# write files to output path
|
||||
with extract_graph_prompt_path.open("wb") as file:
|
||||
file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
|
||||
with entity_summarization_prompt_path.open("wb") as file:
|
||||
file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
|
||||
with community_summarization_prompt_path.open("wb") as file:
|
||||
file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
|
||||
success(f"Prompts written to {output_path}")
|
||||
logger.info("Prompts written to %s", output_path)
|
||||
else:
|
||||
error("No output path provided. Skipping writing prompts.")
|
||||
logger.error("No output path provided. Skipping writing prompts.")
|
||||
|
||||
@ -12,14 +12,13 @@ import graphrag.api as api
|
||||
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks
|
||||
from graphrag.config.load_config import load_config
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
from graphrag.utils.api import create_storage_from_config
|
||||
from graphrag.utils.storage import load_table_from_storage, storage_has_table
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas as pd
|
||||
|
||||
logger = PrintProgressLogger("")
|
||||
# ruff: noqa: T201
|
||||
|
||||
|
||||
def run_global_search(
|
||||
@ -31,6 +30,7 @@ def run_global_search(
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
verbose: bool,
|
||||
):
|
||||
"""Perform a global search with a given query.
|
||||
|
||||
@ -59,10 +59,6 @@ def run_global_search(
|
||||
final_community_reports_list = dataframe_dict["community_reports"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.success(
|
||||
f"Running Multi-index Global Search: {dataframe_dict['index_names']}"
|
||||
)
|
||||
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_global_search(
|
||||
config=config,
|
||||
@ -75,11 +71,10 @@ def run_global_search(
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Global Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Global Search API
|
||||
@ -110,11 +105,12 @@ def run_global_search(
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
callbacks=[callbacks],
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -129,11 +125,11 @@ def run_global_search(
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Global Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -145,6 +141,7 @@ def run_local_search(
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
verbose: bool,
|
||||
):
|
||||
"""Perform a local search with a given query.
|
||||
|
||||
@ -178,10 +175,6 @@ def run_local_search(
|
||||
final_relationships_list = dataframe_dict["relationships"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.success(
|
||||
f"Running Multi-index Local Search: {dataframe_dict['index_names']}"
|
||||
)
|
||||
|
||||
# If any covariates tables are missing from any index, set the covariates list to None
|
||||
if len(dataframe_dict["covariates"]) != dataframe_dict["num_indexes"]:
|
||||
final_covariates_list = None
|
||||
@ -202,11 +195,11 @@ def run_local_search(
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Local Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Local Search API
|
||||
@ -242,11 +235,12 @@ def run_local_search(
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
callbacks=[callbacks],
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -263,11 +257,11 @@ def run_local_search(
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Local Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -279,6 +273,7 @@ def run_drift_search(
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
verbose: bool,
|
||||
):
|
||||
"""Perform a local search with a given query.
|
||||
|
||||
@ -310,10 +305,6 @@ def run_drift_search(
|
||||
final_relationships_list = dataframe_dict["relationships"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.success(
|
||||
f"Running Multi-index Drift Search: {dataframe_dict['index_names']}"
|
||||
)
|
||||
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_drift_search(
|
||||
config=config,
|
||||
@ -327,11 +318,11 @@ def run_drift_search(
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"DRIFT Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Drift Search API
|
||||
@ -365,11 +356,12 @@ def run_drift_search(
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
callbacks=[callbacks],
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -386,11 +378,11 @@ def run_drift_search(
|
||||
community_level=community_level,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"DRIFT Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
|
||||
@ -400,6 +392,7 @@ def run_basic_search(
|
||||
root_dir: Path,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
verbose: bool,
|
||||
):
|
||||
"""Perform a basics search with a given query.
|
||||
|
||||
@ -423,10 +416,6 @@ def run_basic_search(
|
||||
final_text_units_list = dataframe_dict["text_units"]
|
||||
index_names = dataframe_dict["index_names"]
|
||||
|
||||
logger.success(
|
||||
f"Running Multi-index Basic Search: {dataframe_dict['index_names']}"
|
||||
)
|
||||
|
||||
response, context_data = asyncio.run(
|
||||
api.multi_index_basic_search(
|
||||
config=config,
|
||||
@ -434,11 +423,11 @@ def run_basic_search(
|
||||
index_names=index_names,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Basic Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
# Otherwise, call the Single-Index Basic Search API
|
||||
@ -461,11 +450,13 @@ def run_basic_search(
|
||||
config=config,
|
||||
text_units=final_text_units,
|
||||
query=query,
|
||||
callbacks=[callbacks],
|
||||
verbose=verbose,
|
||||
):
|
||||
full_response += stream_chunk
|
||||
print(stream_chunk, end="") # noqa: T201
|
||||
sys.stdout.flush() # flush output buffer to display text immediately
|
||||
print() # noqa: T201
|
||||
print(stream_chunk, end="")
|
||||
sys.stdout.flush()
|
||||
print()
|
||||
return full_response, context_data
|
||||
|
||||
return asyncio.run(run_streaming_search())
|
||||
@ -475,11 +466,11 @@ def run_basic_search(
|
||||
config=config,
|
||||
text_units=final_text_units,
|
||||
query=query,
|
||||
verbose=verbose,
|
||||
)
|
||||
)
|
||||
logger.success(f"Basic Search Response:\n{response}")
|
||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||
# External users should use the API directly to get the response and context data.
|
||||
print(response)
|
||||
|
||||
return response, context_data
|
||||
|
||||
|
||||
|
||||
@ -3,9 +3,10 @@
|
||||
|
||||
"""Common default configuration values."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import ClassVar
|
||||
|
||||
from graphrag.config.embeddings import default_embeddings
|
||||
from graphrag.config.enums import (
|
||||
@ -14,32 +15,63 @@ from graphrag.config.enums import (
|
||||
CacheType,
|
||||
ChunkStrategyType,
|
||||
InputFileType,
|
||||
InputType,
|
||||
ModelType,
|
||||
NounPhraseExtractorType,
|
||||
OutputType,
|
||||
ReportingType,
|
||||
StorageType,
|
||||
VectorStoreType,
|
||||
)
|
||||
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
|
||||
EN_STOP_WORDS,
|
||||
)
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
|
||||
RateLimiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
|
||||
StaticRateLimiter,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
|
||||
ExponentialRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
|
||||
IncrementalWaitRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
|
||||
NativeRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
|
||||
RandomWaitRetry,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.retry import Retry
|
||||
|
||||
DEFAULT_OUTPUT_BASE_DIR = "output"
|
||||
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.OpenAIChat
|
||||
DEFAULT_CHAT_MODEL_TYPE = ModelType.Chat
|
||||
DEFAULT_CHAT_MODEL = "gpt-4-turbo-preview"
|
||||
DEFAULT_CHAT_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_EMBEDDING_MODEL_ID = "default_embedding_model"
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.OpenAIEmbedding
|
||||
DEFAULT_EMBEDDING_MODEL_TYPE = ModelType.Embedding
|
||||
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_EMBEDDING_MODEL_AUTH_TYPE = AuthType.APIKey
|
||||
DEFAULT_MODEL_PROVIDER = "openai"
|
||||
DEFAULT_VECTOR_STORE_ID = "default_vector_store"
|
||||
|
||||
ENCODING_MODEL = "cl100k_base"
|
||||
COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
|
||||
DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
|
||||
"native": NativeRetry,
|
||||
"exponential_backoff": ExponentialRetry,
|
||||
"random_wait": RandomWaitRetry,
|
||||
"incremental_wait": IncrementalWaitRetry,
|
||||
}
|
||||
|
||||
DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
|
||||
"static": StaticRateLimiter,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicSearchDefaults:
|
||||
"""Default values for basic search."""
|
||||
@ -55,7 +87,7 @@ class BasicSearchDefaults:
|
||||
class CacheDefaults:
|
||||
"""Default values for cache."""
|
||||
|
||||
type = CacheType.file
|
||||
type: ClassVar[CacheType] = CacheType.file
|
||||
base_dir: str = "cache"
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
@ -70,7 +102,7 @@ class ChunksDefaults:
|
||||
size: int = 1200
|
||||
overlap: int = 100
|
||||
group_by_columns: list[str] = field(default_factory=lambda: ["id"])
|
||||
strategy = ChunkStrategyType.tokens
|
||||
strategy: ClassVar[ChunkStrategyType] = ChunkStrategyType.tokens
|
||||
encoding_model: str = "cl100k_base"
|
||||
prepend_metadata: bool = False
|
||||
chunk_size_includes_metadata: bool = False
|
||||
@ -120,8 +152,8 @@ class DriftSearchDefaults:
|
||||
local_search_temperature: float = 0
|
||||
local_search_top_p: float = 1
|
||||
local_search_n: int = 1
|
||||
local_search_llm_max_gen_tokens = None
|
||||
local_search_llm_max_gen_completion_tokens = None
|
||||
local_search_llm_max_gen_tokens: int | None = None
|
||||
local_search_llm_max_gen_completion_tokens: int | None = None
|
||||
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
|
||||
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
|
||||
|
||||
@ -184,7 +216,9 @@ class ExtractGraphDefaults:
|
||||
class TextAnalyzerDefaults:
|
||||
"""Default values for text analyzer."""
|
||||
|
||||
extractor_type = NounPhraseExtractorType.RegexEnglish
|
||||
extractor_type: ClassVar[NounPhraseExtractorType] = (
|
||||
NounPhraseExtractorType.RegexEnglish
|
||||
)
|
||||
model_name: str = "en_core_web_md"
|
||||
max_word_length: int = 15
|
||||
word_delimiter: str = " "
|
||||
@ -213,6 +247,7 @@ class ExtractGraphNLPDefaults:
|
||||
normalize_edge_weights: bool = True
|
||||
text_analyzer: TextAnalyzerDefaults = field(default_factory=TextAnalyzerDefaults)
|
||||
concurrent_requests: int = 25
|
||||
async_mode: AsyncType = AsyncType.Threaded
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -234,16 +269,31 @@ class GlobalSearchDefaults:
|
||||
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageDefaults:
|
||||
"""Default values for storage."""
|
||||
|
||||
type: ClassVar[StorageType] = StorageType.file
|
||||
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
storage_account_blob_url: None = None
|
||||
cosmosdb_account_url: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputStorageDefaults(StorageDefaults):
|
||||
"""Default values for input storage."""
|
||||
|
||||
base_dir: str = "input"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputDefaults:
|
||||
"""Default values for input."""
|
||||
|
||||
type = InputType.file
|
||||
file_type = InputFileType.text
|
||||
base_dir: str = "input"
|
||||
connection_string: None = None
|
||||
storage_account_blob_url: None = None
|
||||
container_name: None = None
|
||||
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
|
||||
file_type: ClassVar[InputFileType] = InputFileType.text
|
||||
encoding: str = "utf-8"
|
||||
file_pattern: str = ""
|
||||
file_filter: None = None
|
||||
@ -257,7 +307,8 @@ class LanguageModelDefaults:
|
||||
"""Default values for language model."""
|
||||
|
||||
api_key: None = None
|
||||
auth_type = AuthType.APIKey
|
||||
auth_type: ClassVar[AuthType] = AuthType.APIKey
|
||||
model_provider: str | None = None
|
||||
encoding_model: str = ""
|
||||
max_tokens: int | None = None
|
||||
temperature: float = 0
|
||||
@ -275,9 +326,10 @@ class LanguageModelDefaults:
|
||||
proxy: None = None
|
||||
audience: None = None
|
||||
model_supports_json: None = None
|
||||
tokens_per_minute: Literal["auto"] = "auto"
|
||||
requests_per_minute: Literal["auto"] = "auto"
|
||||
retry_strategy: str = "native"
|
||||
tokens_per_minute: None = None
|
||||
requests_per_minute: None = None
|
||||
rate_limit_strategy: str | None = "static"
|
||||
retry_strategy: str = "exponential_backoff"
|
||||
max_retries: int = 10
|
||||
max_retry_wait: float = 10.0
|
||||
concurrent_requests: int = 25
|
||||
@ -301,15 +353,10 @@ class LocalSearchDefaults:
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputDefaults:
|
||||
class OutputDefaults(StorageDefaults):
|
||||
"""Default values for output."""
|
||||
|
||||
type = OutputType.file
|
||||
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
storage_account_blob_url: None = None
|
||||
cosmosdb_account_url: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -329,7 +376,7 @@ class PruneGraphDefaults:
|
||||
class ReportingDefaults:
|
||||
"""Default values for reporting."""
|
||||
|
||||
type = ReportingType.file
|
||||
type: ClassVar[ReportingType] = ReportingType.file
|
||||
base_dir: str = "logs"
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
@ -364,21 +411,17 @@ class UmapDefaults:
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateIndexOutputDefaults:
|
||||
class UpdateIndexOutputDefaults(StorageDefaults):
|
||||
"""Default values for update index output."""
|
||||
|
||||
type = OutputType.file
|
||||
base_dir: str = "update_output"
|
||||
connection_string: None = None
|
||||
container_name: None = None
|
||||
storage_account_blob_url: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorStoreDefaults:
|
||||
"""Default values for vector stores."""
|
||||
|
||||
type = VectorStoreType.LanceDB.value
|
||||
type: ClassVar[str] = VectorStoreType.LanceDB.value
|
||||
db_uri: str = str(Path(DEFAULT_OUTPUT_BASE_DIR) / "lancedb")
|
||||
container_name: str = "default"
|
||||
overwrite: bool = True
|
||||
@ -386,6 +429,7 @@ class VectorStoreDefaults:
|
||||
api_key: None = None
|
||||
audience: None = None
|
||||
database_name: None = None
|
||||
schema: None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -395,6 +439,7 @@ class GraphRagConfigDefaults:
|
||||
root_dir: str = ""
|
||||
models: dict = field(default_factory=dict)
|
||||
reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
|
||||
storage: StorageDefaults = field(default_factory=StorageDefaults)
|
||||
output: OutputDefaults = field(default_factory=OutputDefaults)
|
||||
outputs: None = None
|
||||
update_index_output: UpdateIndexOutputDefaults = field(
|
||||
|
||||
@ -29,14 +29,14 @@ default_embeddings: list[str] = [
|
||||
]
|
||||
|
||||
|
||||
def create_collection_name(
|
||||
def create_index_name(
|
||||
container_name: str, embedding_name: str, validate: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Create a collection name for the embedding store.
|
||||
Create a index name for the embedding store.
|
||||
|
||||
Within any given vector store, we can have multiple sets of embeddings organized into projects.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the collection name for differentiation.
|
||||
The `container` param is used for this partitioning, and is added as a prefix to the index name for differentiation.
|
||||
|
||||
The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings
|
||||
|
||||
|
||||
@ -42,20 +42,7 @@ class InputFileType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class InputType(str, Enum):
|
||||
"""The input type for the pipeline."""
|
||||
|
||||
file = "file"
|
||||
"""The file storage type."""
|
||||
blob = "blob"
|
||||
"""The blob storage type."""
|
||||
|
||||
def __repr__(self):
|
||||
"""Get a string representation."""
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class OutputType(str, Enum):
|
||||
class StorageType(str, Enum):
|
||||
"""The output type for the pipeline."""
|
||||
|
||||
file = "file"
|
||||
@ -72,13 +59,19 @@ class OutputType(str, Enum):
|
||||
return f'"{self.value}"'
|
||||
|
||||
|
||||
class VectorStoreType(str, Enum):
|
||||
"""The supported vector store types."""
|
||||
|
||||
LanceDB = "lancedb"
|
||||
AzureAISearch = "azure_ai_search"
|
||||
CosmosDB = "cosmosdb"
|
||||
|
||||
|
||||
class ReportingType(str, Enum):
|
||||
"""The reporting configuration type for the pipeline."""
|
||||
|
||||
file = "file"
|
||||
"""The file reporting configuration type."""
|
||||
console = "console"
|
||||
"""The console reporting configuration type."""
|
||||
blob = "blob"
|
||||
"""The blob reporting configuration type."""
|
||||
|
||||
@ -93,10 +86,12 @@ class ModelType(str, Enum):
|
||||
# Embeddings
|
||||
OpenAIEmbedding = "openai_embedding"
|
||||
AzureOpenAIEmbedding = "azure_openai_embedding"
|
||||
Embedding = "embedding"
|
||||
|
||||
# Chat Completion
|
||||
OpenAIChat = "openai_chat"
|
||||
AzureOpenAIChat = "azure_openai_chat"
|
||||
Chat = "chat"
|
||||
|
||||
# Debug
|
||||
MockChat = "mock_chat"
|
||||
@ -152,6 +147,10 @@ class IndexingMethod(str, Enum):
|
||||
"""Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model."""
|
||||
Fast = "fast"
|
||||
"""Fast indexing, using NLP for graph construction and language model for summarization."""
|
||||
StandardUpdate = "standard-update"
|
||||
"""Incremental update with standard indexing."""
|
||||
FastUpdate = "fast-update"
|
||||
"""Incremental update with fast indexing."""
|
||||
|
||||
|
||||
class NounPhraseExtractorType(str, Enum):
|
||||
|
||||
@ -33,15 +33,6 @@ class AzureApiVersionMissingError(ValueError):
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class AzureDeploymentNameMissingError(ValueError):
|
||||
"""Azure Deployment Name missing error."""
|
||||
|
||||
def __init__(self, llm_type: str) -> None:
|
||||
"""Init method definition."""
|
||||
msg = f"Deployment name is required for {llm_type}. Please rerun `graphrag init` set the deployment_name."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class LanguageModelConfigMissingError(ValueError):
|
||||
"""Missing model configuration error."""
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ def get_embedding_settings(
|
||||
vector_store_params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Transform GraphRAG config into settings for workflows."""
|
||||
# TEMP
|
||||
embeddings_llm_settings = settings.get_language_model_config(
|
||||
settings.embed_text.model_id
|
||||
)
|
||||
|
||||
@ -19,48 +19,42 @@ INIT_YAML = f"""\
|
||||
|
||||
models:
|
||||
{defs.DEFAULT_CHAT_MODEL_ID}:
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value} # or azure_openai_chat
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
type: {defs.DEFAULT_CHAT_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
|
||||
model: {defs.DEFAULT_CHAT_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
|
||||
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value} # or azure_openai_embedding
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value} # or azure_managed_identity
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
# audience: "https://cognitiveservices.azure.com/.default"
|
||||
# organization: <organization_id>
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# deployment_name: <azure_model_deployment_name>
|
||||
# encoding_model: {defs.ENCODING_MODEL} # automatically set by tiktoken if left undefined
|
||||
model_supports_json: true # recommended if this is available for your model.
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests} # max number of simultaneous LLM requests allowed
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: native
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: {language_model_defaults.tokens_per_minute} # set to null to disable rate limiting
|
||||
requests_per_minute: {language_model_defaults.requests_per_minute} # set to null to disable rate limiting
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
|
||||
type: {defs.DEFAULT_EMBEDDING_MODEL_TYPE.value}
|
||||
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
|
||||
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
|
||||
api_key: ${{GRAPHRAG_API_KEY}}
|
||||
model: {defs.DEFAULT_EMBEDDING_MODEL}
|
||||
# api_base: https://<instance>.openai.azure.com
|
||||
# api_version: 2024-05-01-preview
|
||||
concurrent_requests: {language_model_defaults.concurrent_requests}
|
||||
async_mode: {language_model_defaults.async_mode.value} # or asyncio
|
||||
retry_strategy: {language_model_defaults.retry_strategy}
|
||||
max_retries: {language_model_defaults.max_retries}
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
|
||||
### Input settings ###
|
||||
|
||||
input:
|
||||
type: {graphrag_config_defaults.input.type.value} # or blob
|
||||
storage:
|
||||
type: {graphrag_config_defaults.input.storage.type.value} # or blob
|
||||
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
|
||||
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
|
||||
base_dir: "{graphrag_config_defaults.input.base_dir}"
|
||||
|
||||
chunks:
|
||||
size: {graphrag_config_defaults.chunks.size}
|
||||
@ -80,7 +74,7 @@ cache:
|
||||
base_dir: "{graphrag_config_defaults.cache.base_dir}"
|
||||
|
||||
reporting:
|
||||
type: {graphrag_config_defaults.reporting.type.value} # [file, blob, cosmosdb]
|
||||
type: {graphrag_config_defaults.reporting.type.value} # [file, blob]
|
||||
base_dir: "{graphrag_config_defaults.reporting.base_dir}"
|
||||
|
||||
vector_store:
|
||||
@ -88,7 +82,6 @@ vector_store:
|
||||
type: {vector_store_defaults.type}
|
||||
db_uri: {vector_store_defaults.db_uri}
|
||||
container_name: {vector_store_defaults.container_name}
|
||||
overwrite: {vector_store_defaults.overwrite}
|
||||
|
||||
### Workflow settings ###
|
||||
|
||||
@ -110,6 +103,7 @@ summarize_descriptions:
|
||||
extract_graph_nlp:
|
||||
text_analyzer:
|
||||
extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
|
||||
async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio
|
||||
|
||||
cluster_graph:
|
||||
max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}
|
||||
|
||||
@ -1,61 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Logging utilities. A unified way for enabling logging."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from graphrag.config.enums import ReportingType
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
|
||||
|
||||
def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
|
||||
"""Enable logging to a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_filepath : str | Path
|
||||
The path to the log file.
|
||||
verbose : bool, default=False
|
||||
Whether to log debug messages.
|
||||
"""
|
||||
log_filepath = Path(log_filepath)
|
||||
log_filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
log_filepath.touch(exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filepath,
|
||||
filemode="a",
|
||||
format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
level=logging.DEBUG if verbose else logging.INFO,
|
||||
)
|
||||
|
||||
|
||||
def enable_logging_with_config(
|
||||
config: GraphRagConfig, verbose: bool = False, filename: str = "indexing-engine.log"
|
||||
) -> tuple[bool, str]:
|
||||
"""Enable logging to a file based on the config.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : GraphRagConfig
|
||||
The configuration.
|
||||
timestamp_value : str
|
||||
The timestamp value representing the directory to place the log files.
|
||||
verbose : bool, default=False
|
||||
Whether to log debug messages.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[bool, str]
|
||||
A tuple of a boolean indicating if logging was enabled and the path to the log file.
|
||||
(False, "") if logging was not enabled.
|
||||
(True, str) if logging was enabled.
|
||||
"""
|
||||
if config.reporting.type == ReportingType.file:
|
||||
log_path = Path(config.reporting.base_dir) / filename
|
||||
enable_logging(log_path, verbose)
|
||||
return (True, str(log_path))
|
||||
return (False, "")
|
||||
@ -12,7 +12,7 @@ from graphrag.config.enums import CacheType
|
||||
class CacheConfig(BaseModel):
|
||||
"""The default configuration section for Cache."""
|
||||
|
||||
type: CacheType = Field(
|
||||
type: CacheType | str = Field(
|
||||
description="The cache type to use.",
|
||||
default=graphrag_config_defaults.cache.type,
|
||||
)
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import NounPhraseExtractorType
|
||||
from graphrag.config.enums import AsyncType, NounPhraseExtractorType
|
||||
|
||||
|
||||
class TextAnalyzerConfig(BaseModel):
|
||||
@ -68,3 +68,7 @@ class ExtractGraphNLPConfig(BaseModel):
|
||||
description="The number of threads to use for the extraction process.",
|
||||
default=graphrag_config_defaults.extract_graph_nlp.concurrent_requests,
|
||||
)
|
||||
async_mode: AsyncType = Field(
|
||||
description="The async mode to use.",
|
||||
default=graphrag_config_defaults.extract_graph_nlp.async_mode,
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.errors import LanguageModelConfigMissingError
|
||||
from graphrag.config.models.basic_search_config import BasicSearchConfig
|
||||
from graphrag.config.models.cache_config import CacheConfig
|
||||
@ -26,17 +27,22 @@ from graphrag.config.models.global_search_config import GlobalSearchConfig
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.config.models.local_search_config import LocalSearchConfig
|
||||
from graphrag.config.models.output_config import OutputConfig
|
||||
from graphrag.config.models.prune_graph_config import PruneGraphConfig
|
||||
from graphrag.config.models.reporting_config import ReportingConfig
|
||||
from graphrag.config.models.snapshots_config import SnapshotsConfig
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
from graphrag.config.models.summarize_descriptions_config import (
|
||||
SummarizeDescriptionsConfig,
|
||||
)
|
||||
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig
|
||||
from graphrag.config.models.umap_config import UmapConfig
|
||||
from graphrag.config.models.vector_store_config import VectorStoreConfig
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
|
||||
RateLimiterFactory,
|
||||
)
|
||||
from graphrag.language_model.providers.litellm.services.retry.retry_factory import (
|
||||
RetryFactory,
|
||||
)
|
||||
|
||||
|
||||
class GraphRagConfig(BaseModel):
|
||||
@ -89,6 +95,47 @@ class GraphRagConfig(BaseModel):
|
||||
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
|
||||
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)
|
||||
|
||||
def _validate_retry_services(self) -> None:
|
||||
"""Validate the retry services configuration."""
|
||||
retry_factory = RetryFactory()
|
||||
|
||||
for model_id, model in self.models.items():
|
||||
if model.retry_strategy != "none":
|
||||
if model.retry_strategy not in retry_factory:
|
||||
msg = f"Retry strategy '{model.retry_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(retry_factory.keys())}"
|
||||
raise ValueError(msg)
|
||||
|
||||
_ = retry_factory.create(
|
||||
strategy=model.retry_strategy,
|
||||
max_retries=model.max_retries,
|
||||
max_retry_wait=model.max_retry_wait,
|
||||
)
|
||||
|
||||
def _validate_rate_limiter_services(self) -> None:
|
||||
"""Validate the rate limiter services configuration."""
|
||||
rate_limiter_factory = RateLimiterFactory()
|
||||
|
||||
for model_id, model in self.models.items():
|
||||
if model.rate_limit_strategy is not None:
|
||||
if model.rate_limit_strategy not in rate_limiter_factory:
|
||||
msg = f"Rate Limiter strategy '{model.rate_limit_strategy}' for model '{model_id}' is not registered. Available strategies: {', '.join(rate_limiter_factory.keys())}"
|
||||
raise ValueError(msg)
|
||||
|
||||
rpm = (
|
||||
model.requests_per_minute
|
||||
if type(model.requests_per_minute) is int
|
||||
else None
|
||||
)
|
||||
tpm = (
|
||||
model.tokens_per_minute
|
||||
if type(model.tokens_per_minute) is int
|
||||
else None
|
||||
)
|
||||
if rpm is not None or tpm is not None:
|
||||
_ = rate_limiter_factory.create(
|
||||
strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
|
||||
)
|
||||
|
||||
input: InputConfig = Field(
|
||||
description="The input configuration.", default=InputConfig()
|
||||
)
|
||||
@ -102,21 +149,31 @@ class GraphRagConfig(BaseModel):
|
||||
else:
|
||||
self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
|
||||
|
||||
def _validate_input_base_dir(self) -> None:
|
||||
"""Validate the input base directory."""
|
||||
if self.input.storage.type == defs.StorageType.file:
|
||||
if self.input.storage.base_dir.strip() == "":
|
||||
msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
|
||||
raise ValueError(msg)
|
||||
self.input.storage.base_dir = str(
|
||||
(Path(self.root_dir) / self.input.storage.base_dir).resolve()
|
||||
)
|
||||
|
||||
chunks: ChunkingConfig = Field(
|
||||
description="The chunking configuration to use.",
|
||||
default=ChunkingConfig(),
|
||||
)
|
||||
"""The chunking configuration to use."""
|
||||
|
||||
output: OutputConfig = Field(
|
||||
output: StorageConfig = Field(
|
||||
description="The output configuration.",
|
||||
default=OutputConfig(),
|
||||
default=StorageConfig(),
|
||||
)
|
||||
"""The output configuration."""
|
||||
|
||||
def _validate_output_base_dir(self) -> None:
|
||||
"""Validate the output base directory."""
|
||||
if self.output.type == defs.OutputType.file:
|
||||
if self.output.type == defs.StorageType.file:
|
||||
if self.output.base_dir.strip() == "":
|
||||
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
|
||||
raise ValueError(msg)
|
||||
@ -124,7 +181,7 @@ class GraphRagConfig(BaseModel):
|
||||
(Path(self.root_dir) / self.output.base_dir).resolve()
|
||||
)
|
||||
|
||||
outputs: dict[str, OutputConfig] | None = Field(
|
||||
outputs: dict[str, StorageConfig] | None = Field(
|
||||
description="A list of output configurations used for multi-index query.",
|
||||
default=graphrag_config_defaults.outputs,
|
||||
)
|
||||
@ -133,7 +190,7 @@ class GraphRagConfig(BaseModel):
|
||||
"""Validate the outputs dict base directories."""
|
||||
if self.outputs:
|
||||
for output in self.outputs.values():
|
||||
if output.type == defs.OutputType.file:
|
||||
if output.type == defs.StorageType.file:
|
||||
if output.base_dir.strip() == "":
|
||||
msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
|
||||
raise ValueError(msg)
|
||||
@ -141,10 +198,9 @@ class GraphRagConfig(BaseModel):
|
||||
(Path(self.root_dir) / output.base_dir).resolve()
|
||||
)
|
||||
|
||||
update_index_output: OutputConfig = Field(
|
||||
update_index_output: StorageConfig = Field(
|
||||
description="The output configuration for the updated index.",
|
||||
default=OutputConfig(
|
||||
type=graphrag_config_defaults.update_index_output.type,
|
||||
default=StorageConfig(
|
||||
base_dir=graphrag_config_defaults.update_index_output.base_dir,
|
||||
),
|
||||
)
|
||||
@ -152,7 +208,7 @@ class GraphRagConfig(BaseModel):
|
||||
|
||||
def _validate_update_index_output_base_dir(self) -> None:
|
||||
"""Validate the update index output base directory."""
|
||||
if self.update_index_output.type == defs.OutputType.file:
|
||||
if self.update_index_output.type == defs.StorageType.file:
|
||||
if self.update_index_output.base_dir.strip() == "":
|
||||
msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
|
||||
raise ValueError(msg)
|
||||
@ -291,6 +347,11 @@ class GraphRagConfig(BaseModel):
|
||||
raise ValueError(msg)
|
||||
store.db_uri = str((Path(self.root_dir) / store.db_uri).resolve())
|
||||
|
||||
def _validate_factories(self) -> None:
|
||||
"""Validate the factories used in the configuration."""
|
||||
self._validate_retry_services()
|
||||
self._validate_rate_limiter_services()
|
||||
|
||||
def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
|
||||
"""Get a model configuration by ID.
|
||||
|
||||
@ -345,9 +406,11 @@ class GraphRagConfig(BaseModel):
|
||||
self._validate_root_dir()
|
||||
self._validate_models()
|
||||
self._validate_input_pattern()
|
||||
self._validate_input_base_dir()
|
||||
self._validate_reporting_base_dir()
|
||||
self._validate_output_base_dir()
|
||||
self._validate_multi_output_base_dirs()
|
||||
self._validate_update_index_output_base_dir()
|
||||
self._validate_vector_store_db_uri()
|
||||
self._validate_factories()
|
||||
return self
|
||||
|
||||
@ -7,36 +7,23 @@ from pydantic import BaseModel, Field
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import InputFileType, InputType
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.storage_config import StorageConfig
|
||||
|
||||
|
||||
class InputConfig(BaseModel):
|
||||
"""The default configuration section for Input."""
|
||||
|
||||
type: InputType = Field(
|
||||
description="The input type to use.",
|
||||
default=graphrag_config_defaults.input.type,
|
||||
storage: StorageConfig = Field(
|
||||
description="The storage configuration to use for reading input documents.",
|
||||
default=StorageConfig(
|
||||
base_dir=graphrag_config_defaults.input.storage.base_dir,
|
||||
),
|
||||
)
|
||||
file_type: InputFileType = Field(
|
||||
description="The input file type to use.",
|
||||
default=graphrag_config_defaults.input.file_type,
|
||||
)
|
||||
base_dir: str = Field(
|
||||
description="The input base directory to use.",
|
||||
default=graphrag_config_defaults.input.base_dir,
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
description="The azure blob storage connection string to use.",
|
||||
default=graphrag_config_defaults.input.connection_string,
|
||||
)
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url to use.",
|
||||
default=graphrag_config_defaults.input.storage_account_blob_url,
|
||||
)
|
||||
container_name: str | None = Field(
|
||||
description="The azure blob storage container name to use.",
|
||||
default=graphrag_config_defaults.input.container_name,
|
||||
)
|
||||
encoding: str = Field(
|
||||
description="The input file encoding to use.",
|
||||
default=defs.graphrag_config_defaults.input.encoding,
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
"""Language model configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import tiktoken
|
||||
@ -14,11 +15,12 @@ from graphrag.config.errors import (
|
||||
ApiKeyMissingError,
|
||||
AzureApiBaseMissingError,
|
||||
AzureApiVersionMissingError,
|
||||
AzureDeploymentNameMissingError,
|
||||
ConflictingSettingsError,
|
||||
)
|
||||
from graphrag.language_model.factory import ModelFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LanguageModelConfig(BaseModel):
|
||||
"""Language model configuration."""
|
||||
@ -73,8 +75,11 @@ class LanguageModelConfig(BaseModel):
|
||||
ConflictingSettingsError
|
||||
If the Azure authentication type conflicts with the model being used.
|
||||
"""
|
||||
if self.auth_type == AuthType.AzureManagedIdentity and (
|
||||
self.type == ModelType.OpenAIChat or self.type == ModelType.OpenAIEmbedding
|
||||
if (
|
||||
self.auth_type == AuthType.AzureManagedIdentity
|
||||
and self.type != ModelType.AzureOpenAIChat
|
||||
and self.type != ModelType.AzureOpenAIEmbedding
|
||||
and self.model_provider != "azure" # indicates Litellm + AOI
|
||||
):
|
||||
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type}. Please rerun `graphrag init` and set the auth_type to api_key."
|
||||
raise ConflictingSettingsError(msg)
|
||||
@ -93,6 +98,35 @@ class LanguageModelConfig(BaseModel):
|
||||
if not ModelFactory.is_supported_model(self.type):
|
||||
msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
|
||||
raise KeyError(msg)
|
||||
if self.type in [
|
||||
"openai_chat",
|
||||
"openai_embedding",
|
||||
"azure_openai_chat",
|
||||
"azure_openai_embedding",
|
||||
]:
|
||||
msg = f"Model config based on fnllm is deprecated and will be removed in GraphRAG v3, please use {ModelType.Chat} or {ModelType.Embedding} instead to switch to LiteLLM config."
|
||||
logger.warning(msg)
|
||||
|
||||
model_provider: str | None = Field(
|
||||
description="The model provider to use.",
|
||||
default=language_model_defaults.model_provider,
|
||||
)
|
||||
|
||||
def _validate_model_provider(self) -> None:
|
||||
"""Validate the model provider.
|
||||
|
||||
Required when using Litellm.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the model provider is not recognized.
|
||||
"""
|
||||
if (self.type == ModelType.Chat or self.type == ModelType.Embedding) and (
|
||||
self.model_provider is None or self.model_provider.strip() == ""
|
||||
):
|
||||
msg = f"Model provider must be specified when using type == {self.type}."
|
||||
raise KeyError(msg)
|
||||
|
||||
model: str = Field(description="The LLM model to use.")
|
||||
encoding_model: str = Field(
|
||||
@ -103,12 +137,27 @@ class LanguageModelConfig(BaseModel):
|
||||
def _validate_encoding_model(self) -> None:
|
||||
"""Validate the encoding model.
|
||||
|
||||
The default behavior is to use an encoding model that matches the LLM model.
|
||||
LiteLLM supports 100+ models and their tokenization. There is no need to
|
||||
set the encoding model when using the new LiteLLM provider as was done with fnllm provider.
|
||||
|
||||
Users can still manually specify a tiktoken based encoding model to use even with the LiteLLM provider
|
||||
in which case the specified encoding model will be used regardless of the LLM model being used, even if
|
||||
it is not an openai based model.
|
||||
|
||||
If not using LiteLLM provider, set the encoding model based on the LLM model name.
|
||||
This is for backward compatibility with existing fnllm provider until fnllm is removed.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the model name is not recognized.
|
||||
"""
|
||||
if self.encoding_model.strip() == "":
|
||||
if (
|
||||
self.type != ModelType.Chat
|
||||
and self.type != ModelType.Embedding
|
||||
and self.encoding_model.strip() == ""
|
||||
):
|
||||
self.encoding_model = tiktoken.encoding_name_for_model(self.model)
|
||||
|
||||
api_base: str | None = Field(
|
||||
@ -129,6 +178,7 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.api_base is None or self.api_base.strip() == ""):
|
||||
raise AzureApiBaseMissingError(self.type)
|
||||
|
||||
@ -150,6 +200,7 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.api_version is None or self.api_version.strip() == ""):
|
||||
raise AzureApiVersionMissingError(self.type)
|
||||
|
||||
@ -171,8 +222,10 @@ class LanguageModelConfig(BaseModel):
|
||||
if (
|
||||
self.type == ModelType.AzureOpenAIChat
|
||||
or self.type == ModelType.AzureOpenAIEmbedding
|
||||
or self.model_provider == "azure" # indicates Litellm + AOI
|
||||
) and (self.deployment_name is None or self.deployment_name.strip() == ""):
|
||||
raise AzureDeploymentNameMissingError(self.type)
|
||||
msg = f"deployment_name is not set for Azure-hosted model. This will default to your model name ({self.model}). If different, this should be set."
|
||||
logger.debug(msg)
|
||||
|
||||
organization: str | None = Field(
|
||||
description="The organization to use for the LLM service.",
|
||||
@ -212,6 +265,14 @@ class LanguageModelConfig(BaseModel):
|
||||
msg = f"Tokens per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.tokens_per_minute}."
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
|
||||
and self.rate_limit_strategy is not None
|
||||
and self.tokens_per_minute == "auto"
|
||||
):
|
||||
msg = f"tokens_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
|
||||
raise ValueError(msg)
|
||||
|
||||
requests_per_minute: int | Literal["auto"] | None = Field(
|
||||
description="The number of requests per minute to use for the LLM service.",
|
||||
default=language_model_defaults.requests_per_minute,
|
||||
@ -230,6 +291,19 @@ class LanguageModelConfig(BaseModel):
|
||||
msg = f"Requests per minute must be a non zero positive number, 'auto' or null. Suggested value: {language_model_defaults.requests_per_minute}."
|
||||
raise ValueError(msg)
|
||||
|
||||
if (
|
||||
(self.type == ModelType.Chat or self.type == ModelType.Embedding)
|
||||
and self.rate_limit_strategy is not None
|
||||
and self.requests_per_minute == "auto"
|
||||
):
|
||||
msg = f"requests_per_minute cannot be set to 'auto' when using type '{self.type}'. Please set it to a positive integer or null to disable."
|
||||
raise ValueError(msg)
|
||||
|
||||
rate_limit_strategy: str | None = Field(
|
||||
description="The rate limit strategy to use for the LLM service.",
|
||||
default=language_model_defaults.rate_limit_strategy,
|
||||
)
|
||||
|
||||
retry_strategy: str = Field(
|
||||
description="The retry strategy to use for the LLM service.",
|
||||
default=language_model_defaults.retry_strategy,
|
||||
@ -318,6 +392,7 @@ class LanguageModelConfig(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
self._validate_type()
|
||||
self._validate_model_provider()
|
||||
self._validate_auth_type()
|
||||
self._validate_api_key()
|
||||
self._validate_tokens_per_minute()
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import OutputType
|
||||
|
||||
|
||||
class OutputConfig(BaseModel):
|
||||
"""The default configuration section for Output."""
|
||||
|
||||
type: OutputType = Field(
|
||||
description="The output type to use.",
|
||||
default=graphrag_config_defaults.output.type,
|
||||
)
|
||||
base_dir: str = Field(
|
||||
description="The base directory for the output.",
|
||||
default=graphrag_config_defaults.output.base_dir,
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
description="The storage connection string to use.",
|
||||
default=graphrag_config_defaults.output.connection_string,
|
||||
)
|
||||
container_name: str | None = Field(
|
||||
description="The storage container name to use.",
|
||||
default=graphrag_config_defaults.output.container_name,
|
||||
)
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url to use.",
|
||||
default=graphrag_config_defaults.output.storage_account_blob_url,
|
||||
)
|
||||
cosmosdb_account_url: str | None = Field(
|
||||
description="The cosmosdb account url to use.",
|
||||
default=graphrag_config_defaults.output.cosmosdb_account_url,
|
||||
)
|
||||
@ -12,7 +12,7 @@ from graphrag.config.enums import ReportingType
|
||||
class ReportingConfig(BaseModel):
|
||||
"""The default configuration section for Reporting."""
|
||||
|
||||
type: ReportingType = Field(
|
||||
type: ReportingType | str = Field(
|
||||
description="The reporting type to use.",
|
||||
default=graphrag_config_defaults.reporting.type,
|
||||
)
|
||||
|
||||
52
graphrag/config/models/storage_config.py
Normal file
52
graphrag/config/models/storage_config.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.enums import StorageType
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
"""The default configuration section for storage."""
|
||||
|
||||
type: StorageType | str = Field(
|
||||
description="The storage type to use.",
|
||||
default=graphrag_config_defaults.storage.type,
|
||||
)
|
||||
base_dir: str = Field(
|
||||
description="The base directory for the output.",
|
||||
default=graphrag_config_defaults.storage.base_dir,
|
||||
)
|
||||
|
||||
# Validate the base dir for multiple OS (use Path)
|
||||
# if not using a cloud storage type.
|
||||
@field_validator("base_dir", mode="before")
|
||||
@classmethod
|
||||
def validate_base_dir(cls, value, info):
|
||||
"""Ensure that base_dir is a valid filesystem path when using local storage."""
|
||||
# info.data contains other field values, including 'type'
|
||||
if info.data.get("type") != StorageType.file:
|
||||
return value
|
||||
return str(Path(value))
|
||||
|
||||
connection_string: str | None = Field(
|
||||
description="The storage connection string to use.",
|
||||
default=graphrag_config_defaults.storage.connection_string,
|
||||
)
|
||||
container_name: str | None = Field(
|
||||
description="The storage container name to use.",
|
||||
default=graphrag_config_defaults.storage.container_name,
|
||||
)
|
||||
storage_account_blob_url: str | None = Field(
|
||||
description="The storage account blob url to use.",
|
||||
default=graphrag_config_defaults.storage.storage_account_blob_url,
|
||||
)
|
||||
cosmosdb_account_url: str | None = Field(
|
||||
description="The cosmosdb account url to use.",
|
||||
default=graphrag_config_defaults.storage.cosmosdb_account_url,
|
||||
)
|
||||
@ -39,7 +39,7 @@ class SummarizeDescriptionsConfig(BaseModel):
|
||||
self, root_dir: str, model_config: LanguageModelConfig
|
||||
) -> dict:
|
||||
"""Get the resolved description summarization strategy."""
|
||||
from graphrag.index.operations.summarize_descriptions import (
|
||||
from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
|
||||
SummarizeStrategyType,
|
||||
)
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class TextEmbeddingConfig(BaseModel):
|
||||
|
||||
def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
|
||||
"""Get the resolved text embedding strategy."""
|
||||
from graphrag.index.operations.embed_text import (
|
||||
from graphrag.index.operations.embed_text.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
)
|
||||
|
||||
|
||||
@ -6,7 +6,9 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from graphrag.config.defaults import vector_store_defaults
|
||||
from graphrag.vector_stores.factory import VectorStoreType
|
||||
from graphrag.config.embeddings import all_embeddings
|
||||
from graphrag.config.enums import VectorStoreType
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
@ -85,9 +87,25 @@ class VectorStoreConfig(BaseModel):
|
||||
default=vector_store_defaults.overwrite,
|
||||
)
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}
|
||||
|
||||
def _validate_embeddings_schema(self) -> None:
|
||||
"""Validate the embeddings schema."""
|
||||
for name in self.embeddings_schema:
|
||||
if name not in all_embeddings:
|
||||
msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names."
|
||||
raise ValueError(msg)
|
||||
|
||||
if self.type == VectorStoreType.CosmosDB:
|
||||
for id_field in self.embeddings_schema:
|
||||
if id_field != "id":
|
||||
msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'."
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_db_uri()
|
||||
self._validate_url()
|
||||
self._validate_embeddings_schema()
|
||||
return self
|
||||
|
||||
66
graphrag/config/models/vector_store_schema_config.py
Normal file
66
graphrag/config/models/vector_store_schema_config.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Parameterization settings for the default configuration."""
|
||||
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
DEFAULT_VECTOR_SIZE: int = 1536
|
||||
|
||||
VALID_IDENTIFIER_REGEX = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
def is_valid_field_name(field: str) -> bool:
|
||||
"""Check if a field name is valid for CosmosDB."""
|
||||
return bool(VALID_IDENTIFIER_REGEX.match(field))
|
||||
|
||||
|
||||
class VectorStoreSchemaConfig(BaseModel):
|
||||
"""The default configuration section for Vector Store Schema."""
|
||||
|
||||
id_field: str = Field(
|
||||
description="The ID field to use.",
|
||||
default="id",
|
||||
)
|
||||
|
||||
vector_field: str = Field(
|
||||
description="The vector field to use.",
|
||||
default="vector",
|
||||
)
|
||||
|
||||
text_field: str = Field(
|
||||
description="The text field to use.",
|
||||
default="text",
|
||||
)
|
||||
|
||||
attributes_field: str = Field(
|
||||
description="The attributes field to use.",
|
||||
default="attributes",
|
||||
)
|
||||
|
||||
vector_size: int = Field(
|
||||
description="The vector size to use.",
|
||||
default=DEFAULT_VECTOR_SIZE,
|
||||
)
|
||||
|
||||
index_name: str | None = Field(description="The index name to use.", default=None)
|
||||
|
||||
def _validate_schema(self) -> None:
|
||||
"""Validate the schema."""
|
||||
for field in [
|
||||
self.id_field,
|
||||
self.vector_field,
|
||||
self.text_field,
|
||||
self.attributes_field,
|
||||
]:
|
||||
if not is_valid_field_name(field):
|
||||
msg = f"Unsafe or invalid field name: {field}"
|
||||
raise ValueError(msg)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_model(self):
|
||||
"""Validate the model."""
|
||||
self._validate_schema()
|
||||
return self
|
||||
@ -9,17 +9,17 @@ from pathlib import Path
|
||||
|
||||
from dotenv import dotenv_values
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_dotenv(root: str) -> None:
|
||||
"""Read a .env file in the given root path."""
|
||||
env_path = Path(root) / ".env"
|
||||
if env_path.exists():
|
||||
log.info("Loading pipeline .env file")
|
||||
logger.info("Loading pipeline .env file")
|
||||
env_config = dotenv_values(f"{env_path}")
|
||||
for key, value in env_config.items():
|
||||
if key not in os.environ:
|
||||
os.environ[key] = value or ""
|
||||
else:
|
||||
log.info("No .env file found at %s", root)
|
||||
logger.info("No .env file found at %s", root)
|
||||
|
||||
4
graphrag/factory/__init__.py
Normal file
4
graphrag/factory/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory module."""
|
||||
68
graphrag/factory/factory.py
Normal file
68
graphrag/factory/factory.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2025 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Factory ABC."""
|
||||
|
||||
from abc import ABC
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class Factory(ABC, Generic[T]):
|
||||
"""Abstract base class for factories."""
|
||||
|
||||
_instance: ClassVar["Factory | None"] = None
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Factory":
|
||||
"""Create a new instance of Factory if it does not exist."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_initialized"):
|
||||
self._services: dict[str, Callable[..., T]] = {}
|
||||
self._initialized = True
|
||||
|
||||
def __contains__(self, strategy: str) -> bool:
|
||||
"""Check if a strategy is registered."""
|
||||
return strategy in self._services
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Get a list of registered strategy names."""
|
||||
return list(self._services.keys())
|
||||
|
||||
def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None:
|
||||
"""
|
||||
Register a new service.
|
||||
|
||||
Args
|
||||
----
|
||||
strategy: The name of the strategy.
|
||||
service_initializer: A callable that creates an instance of T.
|
||||
"""
|
||||
self._services[strategy] = service_initializer
|
||||
|
||||
def create(self, *, strategy: str, **kwargs: Any) -> T:
|
||||
"""
|
||||
Create a service instance based on the strategy.
|
||||
|
||||
Args
|
||||
----
|
||||
strategy: The name of the strategy.
|
||||
**kwargs: Additional arguments to pass to the service initializer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An instance of T.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the strategy is not registered.
|
||||
"""
|
||||
if strategy not in self._services:
|
||||
msg = f"Strategy '{strategy}' is not registered."
|
||||
raise ValueError(msg)
|
||||
return self._services[strategy](**kwargs)
|
||||
@ -10,19 +10,17 @@ import pandas as pd
|
||||
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.util import load_files, process_data_columns
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_csv(
|
||||
config: InputConfig,
|
||||
progress: ProgressLogger | None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
"""Load csv inputs from a directory."""
|
||||
log.info("Loading csv files from %s", config.base_dir)
|
||||
logger.info("Loading csv files from %s", config.storage.base_dir)
|
||||
|
||||
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
|
||||
if group is None:
|
||||
@ -42,4 +40,4 @@ async def load_csv(
|
||||
|
||||
return data
|
||||
|
||||
return await load_files(load_file, config, storage, progress)
|
||||
return await load_files(load_file, config, storage)
|
||||
|
||||
@ -5,22 +5,18 @@
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.config.enums import InputFileType, InputType
|
||||
from graphrag.config.enums import InputFileType
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.csv import load_csv
|
||||
from graphrag.index.input.json import load_json
|
||||
from graphrag.index.input.text import load_text
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.null_progress import NullProgressLogger
|
||||
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
|
||||
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
|
||||
InputFileType.text: load_text,
|
||||
InputFileType.csv: load_csv,
|
||||
@ -30,49 +26,15 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
|
||||
|
||||
async def create_input(
|
||||
config: InputConfig,
|
||||
progress_reporter: ProgressLogger | None = None,
|
||||
root_dir: str | None = None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
"""Instantiate input data for a pipeline."""
|
||||
root_dir = root_dir or ""
|
||||
log.info("loading input from root_dir=%s", config.base_dir)
|
||||
progress_reporter = progress_reporter or NullProgressLogger()
|
||||
|
||||
match config.type:
|
||||
case InputType.blob:
|
||||
log.info("using blob storage input")
|
||||
if config.container_name is None:
|
||||
msg = "Container name required for blob storage"
|
||||
raise ValueError(msg)
|
||||
if (
|
||||
config.connection_string is None
|
||||
and config.storage_account_blob_url is None
|
||||
):
|
||||
msg = "Connection string or storage account blob url required for blob storage"
|
||||
raise ValueError(msg)
|
||||
storage = BlobPipelineStorage(
|
||||
connection_string=config.connection_string,
|
||||
storage_account_blob_url=config.storage_account_blob_url,
|
||||
container_name=config.container_name,
|
||||
path_prefix=config.base_dir,
|
||||
)
|
||||
case InputType.file:
|
||||
log.info("using file storage for input")
|
||||
storage = FilePipelineStorage(
|
||||
root_dir=str(Path(root_dir) / (config.base_dir or ""))
|
||||
)
|
||||
case _:
|
||||
log.info("using file storage for input")
|
||||
storage = FilePipelineStorage(
|
||||
root_dir=str(Path(root_dir) / (config.base_dir or ""))
|
||||
)
|
||||
logger.info("loading input from root_dir=%s", config.storage.base_dir)
|
||||
|
||||
if config.file_type in loaders:
|
||||
progress = progress_reporter.child(
|
||||
f"Loading Input ({config.file_type})", transient=False
|
||||
)
|
||||
logger.info("Loading Input %s", config.file_type)
|
||||
loader = loaders[config.file_type]
|
||||
result = await loader(config, progress, storage)
|
||||
result = await loader(config, storage)
|
||||
# Convert metadata columns to strings and collapse them into a JSON object
|
||||
if config.metadata:
|
||||
if all(col in result.columns for col in config.metadata):
|
||||
|
||||
@ -10,19 +10,17 @@ import pandas as pd
|
||||
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.util import load_files, process_data_columns
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_json(
|
||||
config: InputConfig,
|
||||
progress: ProgressLogger | None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
"""Load json inputs from a directory."""
|
||||
log.info("Loading json files from %s", config.base_dir)
|
||||
logger.info("Loading json files from %s", config.storage.base_dir)
|
||||
|
||||
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
|
||||
if group is None:
|
||||
@ -46,4 +44,4 @@ async def load_json(
|
||||
|
||||
return data
|
||||
|
||||
return await load_files(load_file, config, storage, progress)
|
||||
return await load_files(load_file, config, storage)
|
||||
|
||||
@ -11,15 +11,13 @@ import pandas as pd
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.input.util import load_files
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_text(
|
||||
config: InputConfig,
|
||||
progress: ProgressLogger | None,
|
||||
storage: PipelineStorage,
|
||||
) -> pd.DataFrame:
|
||||
"""Load text inputs from a directory."""
|
||||
@ -34,4 +32,4 @@ async def load_text(
|
||||
new_item["creation_date"] = await storage.get_creation_date(path)
|
||||
return pd.DataFrame([new_item])
|
||||
|
||||
return await load_files(load_file, config, storage, progress)
|
||||
return await load_files(load_file, config, storage)
|
||||
|
||||
@ -11,29 +11,26 @@ import pandas as pd
|
||||
|
||||
from graphrag.config.models.input_config import InputConfig
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.storage.pipeline_storage import PipelineStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_files(
|
||||
loader: Any,
|
||||
config: InputConfig,
|
||||
storage: PipelineStorage,
|
||||
progress: ProgressLogger | None,
|
||||
) -> pd.DataFrame:
|
||||
"""Load files from storage and apply a loader function."""
|
||||
files = list(
|
||||
storage.find(
|
||||
re.compile(config.file_pattern),
|
||||
progress=progress,
|
||||
file_filter=config.file_filter,
|
||||
)
|
||||
)
|
||||
|
||||
if len(files) == 0:
|
||||
msg = f"No {config.file_type} files found in {config.base_dir}"
|
||||
msg = f"No {config.file_type} files found in {config.storage.base_dir}"
|
||||
raise ValueError(msg)
|
||||
|
||||
files_loaded = []
|
||||
@ -42,17 +39,17 @@ async def load_files(
|
||||
try:
|
||||
files_loaded.append(await loader(file, group))
|
||||
except Exception as e: # noqa: BLE001 (catching Exception is fine here)
|
||||
log.warning("Warning! Error loading file %s. Skipping...", file)
|
||||
log.warning("Error: %s", e)
|
||||
logger.warning("Warning! Error loading file %s. Skipping...", file)
|
||||
logger.warning("Error: %s", e)
|
||||
|
||||
log.info(
|
||||
logger.info(
|
||||
"Found %d %s files, loading %d", len(files), config.file_type, len(files_loaded)
|
||||
)
|
||||
result = pd.concat(files_loaded)
|
||||
total_files_log = (
|
||||
f"Total number of unfiltered {config.file_type} rows: {len(result)}"
|
||||
)
|
||||
log.info(total_files_log)
|
||||
logger.info(total_files_log)
|
||||
return result
|
||||
|
||||
|
||||
@ -66,7 +63,7 @@ def process_data_columns(
|
||||
)
|
||||
if config.text_column is not None and "text" not in documents.columns:
|
||||
if config.text_column not in documents.columns:
|
||||
log.warning(
|
||||
logger.warning(
|
||||
"text_column %s not found in csv file %s",
|
||||
config.text_column,
|
||||
path,
|
||||
@ -75,7 +72,7 @@ def process_data_columns(
|
||||
documents["text"] = documents.apply(lambda x: x[config.text_column], axis=1)
|
||||
if config.title_column is not None:
|
||||
if config.title_column not in documents.columns:
|
||||
log.warning(
|
||||
logger.warning(
|
||||
"title_column %s not found in csv file %s",
|
||||
config.title_column,
|
||||
path,
|
||||
|
||||
@ -24,12 +24,17 @@ async def build_noun_graph(
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
normalize_edge_weights: bool,
|
||||
num_threads: int = 4,
|
||||
async_mode: AsyncType = AsyncType.Threaded,
|
||||
cache: PipelineCache | None = None,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Build a noun graph from text units."""
|
||||
text_units = text_unit_df.loc[:, ["id", "text"]]
|
||||
nodes_df = await _extract_nodes(
|
||||
text_units, text_analyzer, num_threads=num_threads, cache=cache
|
||||
text_units,
|
||||
text_analyzer,
|
||||
num_threads=num_threads,
|
||||
async_mode=async_mode,
|
||||
cache=cache,
|
||||
)
|
||||
edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
|
||||
return (nodes_df, edges_df)
|
||||
@ -39,6 +44,7 @@ async def _extract_nodes(
|
||||
text_unit_df: pd.DataFrame,
|
||||
text_analyzer: BaseNounPhraseExtractor,
|
||||
num_threads: int = 4,
|
||||
async_mode: AsyncType = AsyncType.Threaded,
|
||||
cache: PipelineCache | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@ -64,7 +70,8 @@ async def _extract_nodes(
|
||||
text_unit_df,
|
||||
extract,
|
||||
num_threads=num_threads,
|
||||
async_type=AsyncType.Threaded,
|
||||
async_type=async_mode,
|
||||
progress_msg="extract noun phrases progress: ",
|
||||
)
|
||||
|
||||
noun_node_df = text_unit_df.explode("noun_phrases")
|
||||
|
||||
@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
import spacy
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseNounPhraseExtractor(metaclass=ABCMeta):
|
||||
@ -54,7 +54,7 @@ class BaseNounPhraseExtractor(metaclass=ABCMeta):
|
||||
return spacy.load(model_name, exclude=exclude)
|
||||
except OSError:
|
||||
msg = f"Model `{model_name}` not found. Attempting to download..."
|
||||
log.info(msg)
|
||||
logger.info(msg)
|
||||
from spacy.cli.download import download
|
||||
|
||||
download(model_name)
|
||||
|
||||
@ -11,7 +11,7 @@ import tiktoken
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig
|
||||
from graphrag.index.operations.chunk_text.typing import TextChunk
|
||||
from graphrag.index.text_splitting.text_splitting import (
|
||||
Tokenizer,
|
||||
TokenChunkerOptions,
|
||||
split_multiple_texts_on_tokens,
|
||||
)
|
||||
from graphrag.logger.progress import ProgressTicker
|
||||
@ -45,7 +45,7 @@ def run_tokens(
|
||||
encode, decode = get_encoding_fn(encoding_name)
|
||||
return split_multiple_texts_on_tokens(
|
||||
input,
|
||||
Tokenizer(
|
||||
TokenChunkerOptions(
|
||||
chunk_overlap=chunk_overlap,
|
||||
tokens_per_chunk=tokens_per_chunk,
|
||||
encode=encode,
|
||||
|
||||
@ -13,7 +13,7 @@ from graphrag.index.utils.stable_lcc import stable_largest_connected_component
|
||||
Communities = list[tuple[int, int, int, list[str]]]
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cluster_graph(
|
||||
@ -24,7 +24,7 @@ def cluster_graph(
|
||||
) -> Communities:
|
||||
"""Apply a hierarchical clustering algorithm to a graph."""
|
||||
if len(graph.nodes) == 0:
|
||||
log.warning("Graph has no nodes")
|
||||
logger.warning("Graph has no nodes")
|
||||
return []
|
||||
|
||||
node_id_to_community_map, parent_mapping = _compute_leiden_communities(
|
||||
|
||||
@ -2,10 +2,3 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""The Indexing Engine text embed package root."""
|
||||
|
||||
from graphrag.index.operations.embed_text.embed_text import (
|
||||
TextEmbedStrategyType,
|
||||
embed_text,
|
||||
)
|
||||
|
||||
__all__ = ["TextEmbedStrategyType", "embed_text"]
|
||||
|
||||
@ -12,12 +12,13 @@ import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.embeddings import create_collection_name
|
||||
from graphrag.config.embeddings import create_index_name
|
||||
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
from graphrag.vector_stores.factory import VectorStoreFactory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per Azure OpenAI Limits
|
||||
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
|
||||
@ -49,9 +50,9 @@ async def embed_text(
|
||||
vector_store_config = strategy.get("vector_store")
|
||||
|
||||
if vector_store_config:
|
||||
collection_name = _get_collection_name(vector_store_config, embedding_name)
|
||||
index_name = _get_index_name(vector_store_config, embedding_name)
|
||||
vector_store: BaseVectorStore = _create_vector_store(
|
||||
vector_store_config, collection_name
|
||||
vector_store_config, index_name, embedding_name
|
||||
)
|
||||
vector_store_workflow_config = vector_store_config.get(
|
||||
embedding_name, vector_store_config
|
||||
@ -141,7 +142,14 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
all_results = []
|
||||
|
||||
num_total_batches = (input.shape[0] + insert_batch_size - 1) // insert_batch_size
|
||||
while insert_batch_size * i < input.shape[0]:
|
||||
logger.info(
|
||||
"uploading text embeddings batch %d/%d of size %d to vector store",
|
||||
i + 1,
|
||||
num_total_batches,
|
||||
insert_batch_size,
|
||||
)
|
||||
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
|
||||
texts: list[str] = batch[embed_column].to_numpy().tolist()
|
||||
titles: list[str] = batch[title].to_numpy().tolist()
|
||||
@ -176,27 +184,46 @@ async def _text_embed_with_vector_store(
|
||||
|
||||
|
||||
def _create_vector_store(
|
||||
vector_store_config: dict, collection_name: str
|
||||
vector_store_config: dict, index_name: str, embedding_name: str | None = None
|
||||
) -> BaseVectorStore:
|
||||
vector_store_type: str = str(vector_store_config.get("type"))
|
||||
if collection_name:
|
||||
vector_store_config.update({"collection_name": collection_name})
|
||||
|
||||
embeddings_schema: dict[str, VectorStoreSchemaConfig] = vector_store_config.get(
|
||||
"embeddings_schema", {}
|
||||
)
|
||||
single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig()
|
||||
|
||||
if (
|
||||
embeddings_schema is not None
|
||||
and embedding_name is not None
|
||||
and embedding_name in embeddings_schema
|
||||
):
|
||||
raw_config = embeddings_schema[embedding_name]
|
||||
if isinstance(raw_config, dict):
|
||||
single_embedding_config = VectorStoreSchemaConfig(**raw_config)
|
||||
else:
|
||||
single_embedding_config = raw_config
|
||||
|
||||
if single_embedding_config.index_name is None:
|
||||
single_embedding_config.index_name = index_name
|
||||
|
||||
vector_store = VectorStoreFactory().create_vector_store(
|
||||
vector_store_type, kwargs=vector_store_config
|
||||
vector_store_schema_config=single_embedding_config,
|
||||
vector_store_type=vector_store_type,
|
||||
**vector_store_config,
|
||||
)
|
||||
|
||||
vector_store.connect(**vector_store_config)
|
||||
return vector_store
|
||||
|
||||
|
||||
def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
|
||||
container_name = vector_store_config.get("container_name", "default")
|
||||
collection_name = create_collection_name(container_name, embedding_name)
|
||||
index_name = create_index_name(container_name, embedding_name)
|
||||
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}"
|
||||
log.info(msg)
|
||||
return collection_name
|
||||
msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}"
|
||||
logger.info(msg)
|
||||
return index_name
|
||||
|
||||
|
||||
def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy:
|
||||
|
||||
@ -21,7 +21,9 @@ async def run( # noqa RUF029 async is required for interface
|
||||
) -> TextEmbeddingResult:
|
||||
"""Run the Claim extraction chain."""
|
||||
input = input if isinstance(input, Iterable) else [input]
|
||||
ticker = progress_ticker(callbacks.progress, len(input))
|
||||
ticker = progress_ticker(
|
||||
callbacks.progress, len(input), description="generate embeddings progress: "
|
||||
)
|
||||
return TextEmbeddingResult(
|
||||
embeddings=[_embed_text(cache, text, ticker) for text in input]
|
||||
)
|
||||
|
||||
@ -18,8 +18,9 @@ from graphrag.index.utils.is_null import is_null
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol.base import EmbeddingModel
|
||||
from graphrag.logger.progress import ProgressTicker, progress_ticker
|
||||
from graphrag.tokenizer.get_tokenizer import get_tokenizer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run(
|
||||
@ -54,7 +55,7 @@ async def run(
|
||||
batch_max_tokens,
|
||||
splitter,
|
||||
)
|
||||
log.info(
|
||||
logger.info(
|
||||
"embedding %d inputs via %d snippets using %d batches. max_batch_size=%d, batch_max_tokens=%d",
|
||||
len(input),
|
||||
len(texts),
|
||||
@ -62,7 +63,11 @@ async def run(
|
||||
batch_size,
|
||||
batch_max_tokens,
|
||||
)
|
||||
ticker = progress_ticker(callbacks.progress, len(text_batches))
|
||||
ticker = progress_ticker(
|
||||
callbacks.progress,
|
||||
len(text_batches),
|
||||
description="generate embeddings progress: ",
|
||||
)
|
||||
|
||||
# Embed each chunk of snippets
|
||||
embeddings = await _execute(model, text_batches, ticker, semaphore)
|
||||
@ -75,7 +80,7 @@ def _get_splitter(
|
||||
config: LanguageModelConfig, batch_max_tokens: int
|
||||
) -> TokenTextSplitter:
|
||||
return TokenTextSplitter(
|
||||
encoding_name=config.encoding_model,
|
||||
tokenizer=get_tokenizer(model_config=config),
|
||||
chunk_size=batch_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from graphrag.prompts.index.extract_claims import (
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -119,7 +119,7 @@ class ClaimExtractor:
|
||||
]
|
||||
source_doc_map[document_id] = text
|
||||
except Exception as e:
|
||||
log.exception("error extracting claim")
|
||||
logger.exception("error extracting claim")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
|
||||
@ -23,7 +23,7 @@ from graphrag.index.operations.extract_covariates.typing import (
|
||||
from graphrag.index.utils.derive_from_rows import derive_from_rows
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
@ -41,7 +41,7 @@ async def extract_covariates(
|
||||
num_threads: int = 4,
|
||||
):
|
||||
"""Extract claims from a piece of text."""
|
||||
log.debug("extract_covariates strategy=%s", strategy)
|
||||
logger.debug("extract_covariates strategy=%s", strategy)
|
||||
if entity_types is None:
|
||||
entity_types = DEFAULT_ENTITY_TYPES
|
||||
|
||||
@ -71,6 +71,7 @@ async def extract_covariates(
|
||||
callbacks,
|
||||
async_type=async_mode,
|
||||
num_threads=num_threads,
|
||||
progress_msg="extract covariates progress: ",
|
||||
)
|
||||
return pd.DataFrame([item for row in results for item in row or []])
|
||||
|
||||
@ -110,8 +111,8 @@ async def run_extract_claims(
|
||||
model_invoker=llm,
|
||||
extraction_prompt=extraction_prompt,
|
||||
max_gleanings=max_gleanings,
|
||||
on_error=lambda e, s, d: (
|
||||
callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None
|
||||
on_error=lambda e, s, d: logger.error(
|
||||
"Claim Extraction Error", exc_info=e, extra={"stack": s, "details": d}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from graphrag.index.operations.extract_graph.typing import (
|
||||
)
|
||||
from graphrag.index.utils.derive_from_rows import derive_from_rows
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
@ -36,7 +36,7 @@ async def extract_graph(
|
||||
num_threads: int = 4,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Extract a graph from a piece of text using a language model."""
|
||||
log.debug("entity_extract strategy=%s", strategy)
|
||||
logger.debug("entity_extract strategy=%s", strategy)
|
||||
if entity_types is None:
|
||||
entity_types = DEFAULT_ENTITY_TYPES
|
||||
strategy = strategy or {}
|
||||
@ -54,7 +54,6 @@ async def extract_graph(
|
||||
result = await strategy_exec(
|
||||
[Document(text=text, id=id)],
|
||||
entity_types,
|
||||
callbacks,
|
||||
cache,
|
||||
strategy_config,
|
||||
)
|
||||
@ -67,6 +66,7 @@ async def extract_graph(
|
||||
callbacks,
|
||||
async_type=async_mode,
|
||||
num_threads=num_threads,
|
||||
progress_msg="extract graph progress: ",
|
||||
)
|
||||
|
||||
entity_dfs = []
|
||||
|
||||
@ -27,7 +27,7 @@ DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -119,7 +119,7 @@ class GraphExtractor:
|
||||
source_doc_map[doc_index] = text
|
||||
all_records[doc_index] = result
|
||||
except Exception as e:
|
||||
log.exception("error extracting graph")
|
||||
logger.exception("error extracting graph")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
|
||||
@ -3,10 +3,11 @@
|
||||
|
||||
"""A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
|
||||
|
||||
import logging
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.defaults import graphrag_config_defaults
|
||||
from graphrag.config.models.language_model_config import LanguageModelConfig
|
||||
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
|
||||
@ -19,11 +20,12 @@ from graphrag.index.operations.extract_graph.typing import (
|
||||
from graphrag.language_model.manager import ModelManager
|
||||
from graphrag.language_model.protocol.base import ChatModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_graph_intelligence(
|
||||
docs: list[Document],
|
||||
entity_types: EntityTypes,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
args: StrategyConfig,
|
||||
) -> EntityExtractionResult:
|
||||
@ -34,18 +36,16 @@ async def run_graph_intelligence(
|
||||
name="extract_graph",
|
||||
model_type=llm_config.type,
|
||||
config=llm_config,
|
||||
callbacks=callbacks,
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
return await run_extract_graph(llm, docs, entity_types, callbacks, args)
|
||||
return await run_extract_graph(llm, docs, entity_types, args)
|
||||
|
||||
|
||||
async def run_extract_graph(
|
||||
model: ChatModel,
|
||||
docs: list[Document],
|
||||
entity_types: EntityTypes,
|
||||
callbacks: WorkflowCallbacks | None,
|
||||
args: StrategyConfig,
|
||||
) -> EntityExtractionResult:
|
||||
"""Run the entity extraction chain."""
|
||||
@ -61,8 +61,8 @@ async def run_extract_graph(
|
||||
model_invoker=model,
|
||||
prompt=extraction_prompt,
|
||||
max_gleanings=max_gleanings,
|
||||
on_error=lambda e, s, d: (
|
||||
callbacks.error("Entity Extraction Error", e, s, d) if callbacks else None
|
||||
on_error=lambda e, s, d: logger.error(
|
||||
"Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
|
||||
),
|
||||
)
|
||||
text_list = [doc.text.strip() for doc in docs]
|
||||
|
||||
@ -11,7 +11,6 @@ from typing import Any
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
ExtractedEntity = dict[str, Any]
|
||||
ExtractedRelationship = dict[str, Any]
|
||||
@ -40,7 +39,6 @@ EntityExtractStrategy = Callable[
|
||||
[
|
||||
list[Document],
|
||||
EntityTypes,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
StrategyConfig,
|
||||
],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user