mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-13 16:47:20 +08:00
Unified search added to graphrag (#1862)
* unified search app added to graphrag repository
* ignore print statements
* update words for unified-search
* fix lint errors
* fix lint error
* fix module name

Co-authored-by: Gaudy Blanco <gaudy-microsoft@MacBook-Pro-m4-Gaudy-For-Work.local>
This commit is contained in:
parent 61769dd47e
commit 0e1a6e3770
@@ -200,3 +200,8 @@ unnavigated
 # Names
 Hochul
 Ashish
+
+#unified-search
+apos
+dearmor
+venv
unified-search-app/Dockerfile (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Dockerfile
# https://eng.ms/docs/more/containers-secure-supply-chain/approved-images
FROM mcr.microsoft.com/oryx/python:3.11

RUN curl -fsSL https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor -o /usr/share/keyrings/microsoft-prod.gpg
RUN apt-get update -y

# Install dependencies
WORKDIR ./
COPY . .
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="${PATH}:/root/.local/bin"
RUN poetry config virtualenvs.in-project true
RUN poetry install --no-root

# Run application
EXPOSE 8501
ENTRYPOINT ["poetry", "run", "streamlit", "run", "./app/home_page.py"]
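
A minimal sketch of building and running this container, assuming you build from the unified-search-app folder; the image name and mounted data path are illustrative, and `DATA_ROOT` is the data-folder variable described in the README below:

```sh
docker build -t unified-search-app .
docker run -p 8501:8501 -v /path/to/projects:/data -e DATA_ROOT=/data unified-search-app
```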
unified-search-app/README.md (new file, 127 lines)
@@ -0,0 +1,127 @@
# Unified Search

Unified demo for GraphRAG search comparisons.

⚠️ This app is maintained for demo/experimental purposes and is not supported. Issue filings on the GraphRAG repo may not be addressed.

## Requirements

- Python 3.11
- Poetry

This sample app is not published to PyPI, so you'll need to clone the GraphRAG repo and run from this folder.

We recommend always using a virtual environment:

- `python -m venv ./venv`
- `source ./venv/bin/activate`
## Run index

Use GraphRAG to index your dataset before running Unified Search. We recommend starting with the [Getting Started guide](https://microsoft.github.io/graphrag/get_started/). You must run GraphRAG indexing with graph embeddings and UMAP enabled to use all of the Unified Search functionality:

```yaml
embed_graph:
  enabled: true

umap:
  enabled: true
```
## Datasets

Unified Search supports multiple GraphRAG indexes through a directory listing file. Create a `listing.json` file in the root folder where all of your datasets are stored (locally or in blob storage), with the following format (one entry per dataset):

```json
[{
    "key": "<key_to_identify_dataset_1>",
    "path": "<path_to_dataset_1>",
    "name": "<name_to_identify_dataset_1>",
    "description": "<description_for_dataset_1>",
    "community_level": "<integer for community level you want to filter>"
},{
    "key": "<key_to_identify_dataset_2>",
    "path": "<path_to_dataset_2>",
    "name": "<name_to_identify_dataset_2>",
    "description": "<description_for_dataset_2>",
    "community_level": "<integer for community level you want to filter>"
}]
```

For example, if you have a folder of GraphRAG indexes called "projects" and inside that you ran the Getting Started instructions, your listing.json in the projects folder could look like:

```json
[{
    "key": "ragtest-demo",
    "path": "ragtest",
    "name": "A Christmas Carol",
    "description": "Getting Started index of the novel A Christmas Carol",
    "community_level": 2
}]
```
### Data Source Configuration

The expected layout of the projects folder is as follows:

- projects_folder
  - listing.json
  - dataset_1
    - settings.yaml
    - .env (optional if you declare your environment variables elsewhere)
    - output
    - prompts
  - dataset_2
    - settings.yaml
    - .env (optional if you declare your environment variables elsewhere)
    - output
    - prompts
  - ...

Note: any other folder inside each dataset folder is ignored and will not affect the app. Also, only the datasets declared inside listing.json are used by Unified Search.
## Storing your datasets

You can host Unified Search datasets locally or in Azure Blob Storage.

### 1. Local data folder

1. Create a local folder with all your data and config as described above.
2. Tell the app where your folder is, using an absolute path in the following environment variable:
   - `DATA_ROOT` = `<data_folder_absolute_path>`
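
For example (the path is illustrative, assuming a bash shell):

```sh
export DATA_ROOT=/home/me/projects
```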
### 2. Azure Blob Storage

1. If you want to use Azure Blob Storage, create a blob storage account with a "data" container and upload all your data and config as described above.
2. Run `az login` and select an account that has read permissions on that storage.
3. Tell the app which blob account to use with the following environment variable:
   - `BLOB_ACCOUNT_NAME` = `<blob_storage_name>`
4. (optional) In your blob account you need to create a container where your projects live. We default this to `data` as mentioned in step one, but if you want to use something else you can set:
   - `BLOB_CONTAINER_NAME` = `<blob_container_with_projects>`
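
For example (the account and container names are illustrative):

```sh
az login
export BLOB_ACCOUNT_NAME=mystorageaccount
export BLOB_CONTAINER_NAME=data  # optional; defaults to "data"
```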
# Run the app

Install all the dependencies: `poetry install`

Run the project using streamlit: `poetry run poe start`

# How to use it

(app screenshot)
## Configuration panel (left panel)

When you run the app you will see two main panels. The left panel provides several configuration options for the app and can be collapsed:

1. **Datasets**: all the datasets you defined inside the listing.json file are shown, in order, in a dropdown.
2. **Number of suggested questions**: lets you choose how many suggested questions are generated.
3. **Search options**: lets you choose which searches to run in the app. At least one search must be enabled to use the app.

## Searches panel (right panel)

The right panel provides several features:

1. At the top you can see general information about the chosen dataset (name and description).
2. Below the dataset information there is a button labeled "Suggest some questions", which analyzes the dataset using global search and generates the most important questions (the number of questions generated is the amount set in the configuration panel). To select a generated question, click the checkbox on its left.
3. A textbox labeled "Ask a question to compare the results", where you can type the question you want to send.
4. Two tabs, Search and Graph Explorer:
   1. Search: all search results are displayed with their citations.
   2. Graph Explorer: this tab is divided into three sections: Community Reports, Entity Graph, and Selected Report.

##### "Suggest some questions" clicked

(app screenshot)

##### Selected question clicked

(app screenshot)

##### Graph Explorer tab

(app screenshot)
unified-search-app/app/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""App module."""
unified-search-app/app/app_logic.py (new file, 368 lines)
@@ -0,0 +1,368 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""App logic module."""

import asyncio
import logging
from typing import TYPE_CHECKING

import streamlit as st
from knowledge_loader.data_sources.loader import (
    create_datasource,
    load_dataset_listing,
)
from knowledge_loader.model import load_model
from rag.typing import SearchResult, SearchType
from state.session_variables import SessionVariables
from ui.search import display_search_result

import graphrag.api as api

if TYPE_CHECKING:
    import pandas as pd

logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


def initialize() -> SessionVariables:
    """Initialize app logic."""
    if "session_variables" not in st.session_state:
        st.set_page_config(
            layout="wide",
            initial_sidebar_state="collapsed",
            page_title="GraphRAG",
        )
        sv = SessionVariables()
        datasets = load_dataset_listing()
        sv.datasets.value = datasets
        sv.dataset.value = (
            st.query_params["dataset"].lower()
            if "dataset" in st.query_params
            else datasets[0].key
        )
        load_dataset(sv.dataset.value, sv)
        st.session_state["session_variables"] = sv
    return st.session_state["session_variables"]


def load_dataset(dataset: str, sv: SessionVariables):
    """Load dataset from the dropdown."""
    sv.dataset.value = dataset
    sv.dataset_config.value = next(
        (d for d in sv.datasets.value if d.key == dataset), None
    )
    if sv.dataset_config.value is not None:
        sv.datasource.value = create_datasource(f"{sv.dataset_config.value.path}")  # type: ignore
        sv.graphrag_config.value = sv.datasource.value.read_settings("settings.yaml")
        load_knowledge_model(sv)


def dataset_name(key: str, sv: SessionVariables) -> str:
    """Get dataset name."""
    return next((d for d in sv.datasets.value if d.key == key), None).name  # type: ignore


async def run_all_searches(query: str, sv: SessionVariables) -> list[SearchResult]:
    """Run all search engines and return the results."""
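    # Streamlit re-executes this script on every interaction, so install a
    # fresh event loop for this run before scheduling the search tasks.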
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    tasks = []
    if sv.include_drift_search.value:
        tasks.append(
            run_drift_search(
                query=query,
                sv=sv,
            )
        )

    if sv.include_basic_rag.value:
        tasks.append(
            run_basic_search(
                query=query,
                sv=sv,
            )
        )
    if sv.include_local_search.value:
        tasks.append(
            run_local_search(
                query=query,
                sv=sv,
            )
        )
    if sv.include_global_search.value:
        tasks.append(
            run_global_search(
                query=query,
                sv=sv,
            )
        )

    return await asyncio.gather(*tasks)


async def run_generate_questions(query: str, sv: SessionVariables):
    """Run global search to generate questions for the dataset."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    tasks = []

    tasks.append(
        run_global_search_question_generation(
            query=query,
            sv=sv,
        )
    )

    return await asyncio.gather(*tasks)


async def run_global_search_question_generation(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run global search question generation process."""
    empty_context_data: dict[str, pd.DataFrame] = {}

    response, context_data = await api.global_search(
        config=sv.graphrag_config.value,
        entities=sv.entities.value,
        communities=sv.communities.value,
        community_reports=sv.community_reports.value,
        dynamic_community_selection=True,
        response_type="Single paragraph",
        community_level=sv.dataset_config.value.community_level,
        query=query,
    )

    # display response and reference context to UI
    return SearchResult(
        search_type=SearchType.Global,
        response=str(response),
        context=context_data if isinstance(context_data, dict) else empty_context_data,
    )


async def run_local_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run local search."""
    print(f"Local search query: {query}")  # noqa T201

    # build local search engine
    response_placeholder = st.session_state[
        f"{SearchType.Local.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Local.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using local search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.local_search(
            config=sv.graphrag_config.value,
            communities=sv.communities.value,
            entities=sv.entities.value,
            community_reports=sv.community_reports.value,
            text_units=sv.text_units.value,
            relationships=sv.relationships.value,
            covariates=sv.covariates.value,
            community_level=sv.dataset_config.value.community_level,
            response_type="Multiple Paragraphs",
            query=query,
        )

        print(f"Local Response: {response}")  # noqa T201
        print(f"Context data: {context_data}")  # noqa T201

        # display response and reference context to UI
        search_result = SearchResult(
            search_type=SearchType.Local,
            response=str(response),
            context=context_data if isinstance(context_data, dict) else empty_context_data,
        )

        display_search_result(
            container=response_container, result=search_result, stats=None
        )

        if "response_lengths" not in st.session_state:
            st.session_state.response_lengths = []

        st.session_state["response_lengths"].append({
            "result": search_result,
            "search": SearchType.Local.value.lower(),
        })

        return search_result


async def run_global_search(query: str, sv: SessionVariables) -> SearchResult:
    """Run global search."""
    print(f"Global search query: {query}")  # noqa T201

    # build global search engine
    response_placeholder = st.session_state[
        f"{SearchType.Global.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[
        f"{SearchType.Global.value.lower()}_container"
    ]

    response_placeholder.empty()
    with response_placeholder, st.spinner("Generating answer using global search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.global_search(
            config=sv.graphrag_config.value,
            entities=sv.entities.value,
            communities=sv.communities.value,
            community_reports=sv.community_reports.value,
            dynamic_community_selection=False,
            response_type="Multiple Paragraphs",
            community_level=sv.dataset_config.value.community_level,
            query=query,
        )

        print(f"Context data: {context_data}")  # noqa T201
        print(f"Global Response: {response}")  # noqa T201

        # display response and reference context to UI
        search_result = SearchResult(
            search_type=SearchType.Global,
            response=str(response),
            context=context_data if isinstance(context_data, dict) else empty_context_data,
        )

        display_search_result(
            container=response_container, result=search_result, stats=None
        )

        if "response_lengths" not in st.session_state:
            st.session_state.response_lengths = []

        st.session_state["response_lengths"].append({
            "result": search_result,
            "search": SearchType.Global.value.lower(),
        })

        return search_result


async def run_drift_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run drift search."""
    print(f"Drift search query: {query}")  # noqa T201

    # build drift search engine
    response_placeholder = st.session_state[
        f"{SearchType.Drift.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Drift.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using drift search..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.drift_search(
            config=sv.graphrag_config.value,
            entities=sv.entities.value,
            communities=sv.communities.value,
            community_reports=sv.community_reports.value,
            text_units=sv.text_units.value,
            relationships=sv.relationships.value,
            community_level=sv.dataset_config.value.community_level,
            response_type="Multiple Paragraphs",
            query=query,
        )

        print(f"Drift Response: {response}")  # noqa T201
        print(f"Context data: {context_data}")  # noqa T201

        # display response and reference context to UI
        search_result = SearchResult(
            search_type=SearchType.Drift,
            response=str(response),
            context=context_data if isinstance(context_data, dict) else empty_context_data,
        )

        display_search_result(
            container=response_container, result=search_result, stats=None
        )

        if "response_lengths" not in st.session_state:
            st.session_state.response_lengths = []

        st.session_state["response_lengths"].append({
            "result": search_result,
            "search": SearchType.Drift.value.lower(),
        })

        return search_result


async def run_basic_search(
    query: str,
    sv: SessionVariables,
) -> SearchResult:
    """Run basic search."""
    print(f"Basic search query: {query}")  # noqa T201

    # build basic search engine
    response_placeholder = st.session_state[
        f"{SearchType.Basic.value.lower()}_response_placeholder"
    ]
    response_container = st.session_state[f"{SearchType.Basic.value.lower()}_container"]

    with response_placeholder, st.spinner("Generating answer using basic RAG..."):
        empty_context_data: dict[str, pd.DataFrame] = {}

        response, context_data = await api.basic_search(
            config=sv.graphrag_config.value,
            text_units=sv.text_units.value,
            query=query,
        )

        print(f"Basic Response: {response}")  # noqa T201
        print(f"Context data: {context_data}")  # noqa T201

        # display response and reference context to UI
        search_result = SearchResult(
            search_type=SearchType.Basic,
            response=str(response),
            context=context_data if isinstance(context_data, dict) else empty_context_data,
        )

        display_search_result(
            container=response_container, result=search_result, stats=None
        )

        if "response_lengths" not in st.session_state:
            st.session_state.response_lengths = []

        st.session_state["response_lengths"].append({
            "search": SearchType.Basic.value.lower(),
            "result": search_result,
        })

        return search_result


def load_knowledge_model(sv: SessionVariables):
    """Load knowledge model from the datasource."""
    print("Loading knowledge model...", sv.dataset.value, sv.dataset_config.value)  # noqa T201
    model = load_model(sv.dataset.value, sv.datasource.value)

    sv.generated_questions.value = []
    sv.selected_question.value = ""
    sv.entities.value = model.entities
    sv.relationships.value = model.relationships
    sv.covariates.value = model.covariates
    sv.community_reports.value = model.community_reports
    sv.communities.value = model.communities
    sv.text_units.value = model.text_units

    return sv
unified-search-app/app/data_config.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Data config module."""

# This file is used to store configurations for the graph-indexed data and the LLM/embeddings models used in the app.

# name of the table in the graph-indexed data where the communities are stored
communities_table = "output/communities"

# name of the table in the graph-indexed data where the community reports are stored
community_report_table = "output/community_reports"

# name of the table in the graph-indexed data where the entity embeddings are stored
entity_table = "output/entities"

# name of the table in the graph-indexed data where the entity relationships are stored
relationship_table = "output/relationships"

# name of the table in the graph-indexed data where the entity covariates are stored
covariate_table = "output/covariates"

# name of the table in the graph-indexed data where the text units are stored
text_unit_table = "output/text_units"

# default number of suggested questions to generate
default_suggested_questions = 5

# default timeout for streamlit cache
default_ttl = 60 * 60 * 24 * 7
unified-search-app/app/home_page.py (new file, 260 lines)
@@ -0,0 +1,260 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Home Page module."""

import asyncio

import streamlit as st
from app_logic import dataset_name, initialize, run_all_searches, run_generate_questions
from rag.typing import SearchType
from st_tabs import TabBar
from state.session_variables import SessionVariables
from ui.full_graph import create_full_graph_ui
from ui.questions_list import create_questions_list_ui
from ui.report_details import create_report_details_ui
from ui.report_list import create_report_list_ui
from ui.search import display_citations, format_suggested_questions, init_search_ui
from ui.sidebar import create_side_bar


async def main():
    """Return main streamlit component to render the app."""
    sv = initialize()

    create_side_bar(sv)

    st.markdown(
        "#### GraphRAG: A Novel Knowledge Graph-based Approach to Retrieval Augmented Generation (RAG)"
    )
    st.markdown("##### Dataset selected: " + dataset_name(sv.dataset.value, sv))
    st.markdown(sv.dataset_config.value.description)

    def on_click_reset(sv: SessionVariables):
        sv.generated_questions.value = []
        sv.selected_question.value = ""
        sv.show_text_input.value = True

    def on_change(sv: SessionVariables):
        sv.question.value = st.session_state[question_input]

    question_input = "question_input"

    generate_questions = st.button("Suggest some questions")

    question = ""

    if len(sv.question.value.strip()) > 0:
        question = sv.question.value

    if generate_questions:
        with st.spinner("Generating suggested questions..."):
            try:
                result = await run_generate_questions(
                    query=f"Generate numbered list only with the top {sv.suggested_questions.value} most important questions of this dataset (numbered list only without titles or anything extra)",
                    sv=sv,
                )
                for result_item in result:
                    questions = format_suggested_questions(result_item.response)
                    sv.generated_questions.value = questions
                    sv.show_text_input.value = False
            except Exception as e:  # noqa: BLE001
                print(f"Search exception: {e}")  # noqa T201
                st.write(e)

    if sv.show_text_input.value is True:
        st.text_input(
            "Ask a question to compare the results",
            key=question_input,
            on_change=on_change,
            value=question,
            kwargs={"sv": sv},
        )

    if len(sv.generated_questions.value) != 0:
        create_questions_list_ui(sv)

    if sv.show_text_input.value is False:
        st.button(label="Reset", on_click=on_click_reset, kwargs={"sv": sv})

    tab_id = TabBar(
        tabs=["Search", "Graph Explorer"],
        color="#fc9e9e",
        activeColor="#ff4b4b",
        default=0,
    )

    if tab_id == 0:
        if len(sv.question.value.strip()) > 0:
            question = sv.question.value

        if sv.selected_question.value != "":
            question = sv.selected_question.value
            sv.question.value = question

        if question:
            st.write(f"##### Answering the question: *{question}*")

        ss_basic = None
        ss_local = None
        ss_global = None
        ss_drift = None

        ss_basic_citations = None
        ss_local_citations = None
        ss_global_citations = None
        ss_drift_citations = None

        count = sum([
            sv.include_basic_rag.value,
            sv.include_local_search.value,
            sv.include_global_search.value,
            sv.include_drift_search.value,
        ])

        if count > 0:
            columns = st.columns(count)
            index = 0
            if sv.include_basic_rag.value:
                ss_basic = columns[index]
                index += 1
            if sv.include_local_search.value:
                ss_local = columns[index]
                index += 1
            if sv.include_global_search.value:
                ss_global = columns[index]
                index += 1
            if sv.include_drift_search.value:
                ss_drift = columns[index]

        else:
            st.write("Please select at least one search option from the sidebar.")

        with st.container():
            if ss_basic:
                with ss_basic:
                    init_search_ui(
                        container=ss_basic,
                        search_type=SearchType.Basic,
                        title="##### GraphRAG: Basic RAG",
                        caption="###### Answer context: Fixed number of text chunks of raw documents",
                    )

            if ss_local:
                with ss_local:
                    init_search_ui(
                        container=ss_local,
                        search_type=SearchType.Local,
                        title="##### GraphRAG: Local Search",
                        caption="###### Answer context: Graph index query results with relevant document text chunks",
                    )

            if ss_global:
                with ss_global:
                    init_search_ui(
                        container=ss_global,
                        search_type=SearchType.Global,
                        title="##### GraphRAG: Global Search",
                        caption="###### Answer context: AI-generated network reports covering all input documents",
                    )

            if ss_drift:
                with ss_drift:
                    init_search_ui(
                        container=ss_drift,
                        search_type=SearchType.Drift,
                        title="##### GraphRAG: Drift Search",
                        caption="###### Answer context: Includes community information",
                    )

        count = sum([
            sv.include_basic_rag.value,
            sv.include_local_search.value,
            sv.include_global_search.value,
            sv.include_drift_search.value,
        ])

        if count > 0:
            columns = st.columns(count)
            index = 0
            if sv.include_basic_rag.value:
                ss_basic_citations = columns[index]
                index += 1
            if sv.include_local_search.value:
                ss_local_citations = columns[index]
                index += 1
            if sv.include_global_search.value:
                ss_global_citations = columns[index]
                index += 1
            if sv.include_drift_search.value:
                ss_drift_citations = columns[index]

        with st.container():
            if ss_basic_citations:
                with ss_basic_citations:
                    st.empty()
            if ss_local_citations:
                with ss_local_citations:
                    st.empty()
            if ss_global_citations:
                with ss_global_citations:
                    st.empty()
            if ss_drift_citations:
                with ss_drift_citations:
                    st.empty()
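
        # Run the searches only when the question changed since the last pass;
        # Streamlit re-executes this whole script on every interaction.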
        if question != "" and question != sv.question_in_progress.value:
            sv.question_in_progress.value = question
            try:
                await run_all_searches(query=question, sv=sv)

                if "response_lengths" not in st.session_state:
                    st.session_state.response_lengths = []

                for result in st.session_state.response_lengths:
                    if result["search"] == SearchType.Basic.value.lower():
                        display_citations(
                            container=ss_basic_citations,
                            result=result["result"],
                        )
                    if result["search"] == SearchType.Local.value.lower():
                        display_citations(
                            container=ss_local_citations,
                            result=result["result"],
                        )
                    if result["search"] == SearchType.Global.value.lower():
                        display_citations(
                            container=ss_global_citations,
                            result=result["result"],
                        )
                    if result["search"] == SearchType.Drift.value.lower():
                        display_citations(
                            container=ss_drift_citations,
                            result=result["result"],
                        )
            except Exception as e:  # noqa: BLE001
                print(f"Search exception: {e}")  # noqa T201
                st.write(e)

    if tab_id == 1:
        report_list, graph, report_content = st.columns([0.20, 0.55, 0.25])

        with report_list:
            st.markdown("##### Community Reports")
            create_report_list_ui(sv)

        with graph:
            title, dropdown = st.columns([0.80, 0.20])
            title.markdown("##### Entity Graph (All entities)")
            dropdown.selectbox(
                "Community level", options=[0, 1], key=sv.graph_community_level.key
            )
            create_full_graph_ui(sv)

        with report_content:
            st.markdown("##### Selected Report")
            create_report_details_ui(sv)


if __name__ == "__main__":
    asyncio.run(main())
unified-search-app/app/knowledge_loader/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Knowledge loader module."""
unified-search-app/app/knowledge_loader/data_prep.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Data prep module."""

import logging

import data_config as config
import pandas as pd
import streamlit as st
from knowledge_loader.data_sources.typing import Datasource

"""
Contains functions to load and prep graph-indexed data from parquet files into dataframes.
These output dataframes will then be used to create knowledge model objects to be used as inputs for the graphrag-orchestration functions.
"""
logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
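

# Note: the leading underscore in the `_datasource` parameters below tells
# st.cache_data not to hash that argument when computing the cache key.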
@st.cache_data(ttl=config.default_ttl)
def get_entity_data(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a dataframe with entity data from the indexed-data."""
    entity_details_df = _datasource.read(config.entity_table)

    print(f"Entity records: {len(entity_details_df)}")  # noqa T201
    print(f"Dataset: {dataset}")  # noqa T201
    return entity_details_df


@st.cache_data(ttl=config.default_ttl)
def get_relationship_data(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a dataframe with entity-entity relationship data from the indexed-data."""
    relationship_df = _datasource.read(config.relationship_table)
    print(f"Relationship records: {len(relationship_df)}")  # noqa T201
    print(f"Dataset: {dataset}")  # noqa T201
    return relationship_df


@st.cache_data(ttl=config.default_ttl)
def get_covariate_data(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a dataframe with covariate data from the indexed-data."""
    covariate_df = _datasource.read(config.covariate_table)
    print(f"Covariate records: {len(covariate_df)}")  # noqa T201
    print(f"Dataset: {dataset}")  # noqa T201
    return covariate_df


@st.cache_data(ttl=config.default_ttl)
def get_text_unit_data(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a dataframe with text units (i.e. chunks of text from the raw documents) from the indexed-data."""
    text_unit_df = _datasource.read(config.text_unit_table)
    print(f"Text unit records: {len(text_unit_df)}")  # noqa T201
    print(f"Dataset: {dataset}")  # noqa T201
    return text_unit_df


@st.cache_data(ttl=config.default_ttl)
def get_community_report_data(
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return a dataframe with community report data from the indexed-data."""
    report_df = _datasource.read(config.community_report_table)
    print(f"Report records: {len(report_df)}")  # noqa T201

    return report_df


@st.cache_data(ttl=config.default_ttl)
def get_communities_data(
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return a dataframe with communities data from the indexed-data."""
    return _datasource.read(config.communities_table)
unified-search-app/app/knowledge_loader/data_sources/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Data sources module."""
unified-search-app/app/knowledge_loader/data_sources/blob_source.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Blob source module."""

import io
import logging
import os
from io import BytesIO

import pandas as pd
import streamlit as st
import yaml
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient, ContainerClient
from knowledge_loader.data_sources.typing import Datasource

from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.graph_rag_config import GraphRagConfig

from .default import blob_account_name, blob_container_name

logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
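

# Cache the authenticated container client for a day so repeated reads don't
# re-authenticate on every Streamlit rerun.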
@st.cache_data(ttl=60 * 60 * 24)
def _get_container(account_name: str, container_name: str) -> ContainerClient:
    """Return container from blob storage."""
    print("LOGIN---------------")  # noqa T201
    account_url = f"https://{account_name}.blob.core.windows.net"
    default_credential = DefaultAzureCredential()
    blob_service_client = BlobServiceClient(account_url, credential=default_credential)
    return blob_service_client.get_container_client(container_name)


def load_blob_prompt_config(
    dataset: str,
    account_name: str | None = blob_account_name,
    container_name: str | None = blob_container_name,
) -> dict[str, str]:
    """Load blob prompt configuration."""
    if account_name is None or container_name is None:
        return {}

    container_client = _get_container(account_name, container_name)
    prompts = {}

    prefix = f"{dataset}/prompts"
    for file in container_client.list_blobs(name_starts_with=prefix):
        map_name = file.name.split("/")[-1].split(".")[0]
        prompts[map_name] = (
            container_client.download_blob(file.name).readall().decode("utf-8")
        )

    return prompts


def load_blob_file(
    dataset: str | None,
    file: str | None,
    account_name: str | None = blob_account_name,
    container_name: str | None = blob_container_name,
) -> BytesIO:
    """Load blob file from container."""
    stream = io.BytesIO()

    if account_name is None or container_name is None:
        logger.warning("No account name or container name provided")
        return stream

    container_client = _get_container(account_name, container_name)
    blob_path = f"{dataset}/{file}" if dataset is not None else file

    container_client.download_blob(blob_path).readinto(stream)

    return stream


class BlobDatasource(Datasource):
    """Datasource that reads from a blob storage parquet file."""

    def __init__(self, database: str):
        """Init method definition."""
        self._database = database

    def read(
        self,
        table: str,
        throw_on_missing: bool = False,
        columns: list[str] | None = None,
    ) -> pd.DataFrame:
        """Read file from container."""
        try:
            data = load_blob_file(self._database, f"{table}.parquet")
        except Exception as err:
            if throw_on_missing:
                error_msg = f"Table {table} does not exist"
                raise FileNotFoundError(error_msg) from err
            logger.warning("Table %s does not exist", table)
            return pd.DataFrame(columns=columns) if columns else pd.DataFrame()

        return pd.read_parquet(data, columns=columns)

    def read_settings(
        self,
        file: str,
        throw_on_missing: bool = False,
    ) -> GraphRagConfig | None:
        """Read settings from container."""
        try:
            settings = load_blob_file(self._database, file)
            settings.seek(0)
            str_settings = settings.read().decode("utf-8")
            config = os.path.expandvars(str_settings)
            settings_yaml = yaml.safe_load(config)
            graphrag_config = create_graphrag_config(values=settings_yaml)
        except Exception as err:
            if throw_on_missing:
                error_msg = f"File {file} does not exist"
                raise FileNotFoundError(error_msg) from err

            logger.warning("File %s does not exist", file)
            return None

        return graphrag_config
unified-search-app/app/knowledge_loader/data_sources/default.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Data sources default module."""

import os

container_name = "data"
blob_container_name = os.getenv("BLOB_CONTAINER_NAME", container_name)
blob_account_name = os.getenv("BLOB_ACCOUNT_NAME")

local_data_root = os.getenv("DATA_ROOT")

LISTING_FILE = "listing.json"

if local_data_root is None and blob_account_name is None:
    error_message = (
        "Either DATA_ROOT or BLOB_ACCOUNT_NAME environment variable must be set."
    )
    raise ValueError(error_message)
unified-search-app/app/knowledge_loader/data_sources/loader.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Loader module."""

import json
import logging
import os

from knowledge_loader.data_sources.blob_source import (
    BlobDatasource,
    load_blob_file,
    load_blob_prompt_config,
)
from knowledge_loader.data_sources.default import (
    LISTING_FILE,
    blob_account_name,
    local_data_root,
)
from knowledge_loader.data_sources.local_source import (
    LocalDatasource,
    load_local_prompt_config,
)
from knowledge_loader.data_sources.typing import DatasetConfig, Datasource

logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
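

# Relative data roots are resolved against this file's directory (an absolute
# DATA_ROOT is used as-is, since os.path.join discards the earlier parts).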
def _get_base_path(
    dataset: str | None, root: str | None, extra_path: str | None = None
) -> str:
    """Construct and return the base path for the given dataset and extra path."""
    return os.path.join(  # noqa: PTH118
        os.path.dirname(os.path.realpath(__file__)),  # noqa: PTH120
        root if root else "",
        dataset if dataset else "",
        *(extra_path.split("/") if extra_path else []),
    )


def create_datasource(dataset_folder: str) -> Datasource:
    """Return a datasource that reads from a local or blob storage parquet file."""
    if blob_account_name is not None and blob_account_name != "":
        return BlobDatasource(dataset_folder)

    base_path = _get_base_path(dataset_folder, local_data_root)
    return LocalDatasource(base_path)


def load_dataset_listing() -> list[DatasetConfig]:
    """Load dataset listing file."""
    datasets = []
    if blob_account_name is not None and blob_account_name != "":
        try:
            blob = load_blob_file(None, LISTING_FILE)
            datasets_str = blob.getvalue().decode("utf-8")
            if datasets_str:
                datasets = json.loads(datasets_str)
        except Exception as e:  # noqa: BLE001
            print(f"Error loading dataset config: {e}")  # noqa T201
            return []
    else:
        base_path = _get_base_path(None, local_data_root, LISTING_FILE)
        with open(base_path, "r") as file:  # noqa: UP015, PTH123
            datasets = json.load(file)

    return [DatasetConfig(**d) for d in datasets]


def load_prompts(dataset: str) -> dict[str, str]:
    """Return the prompts configuration for a specific dataset."""
    if blob_account_name is not None and blob_account_name != "":
        return load_blob_prompt_config(dataset)

    base_path = _get_base_path(dataset, local_data_root, "prompts")
    return load_local_prompt_config(base_path)
unified-search-app/app/knowledge_loader/data_sources/local_source.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Local source module."""

import logging
import os
from pathlib import Path

import pandas as pd
from knowledge_loader.data_sources.typing import Datasource

from graphrag.config.load_config import load_config
from graphrag.config.models.graph_rag_config import GraphRagConfig

logging.basicConfig(level=logging.INFO)
logging.getLogger("azure").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


def load_local_prompt_config(base_path="") -> dict[str, str]:
    """Load local prompt configuration."""
    # for each file inside folder base_path
    prompts = {}

    for path in os.listdir(base_path):  # noqa: PTH208
        with open(os.path.join(base_path, path), "r") as f:  # noqa: UP015, PTH123, PTH118
            map_name = path.split(".")[0]
            prompts[map_name] = f.read()
    return prompts


class LocalDatasource(Datasource):
    """Datasource that reads from a local parquet file."""

    _base_path: str

    def __init__(self, base_path: str):
        """Init method definition."""
        self._base_path = base_path

    def read(
        self,
        table: str,
        throw_on_missing: bool = False,
        columns: list[str] | None = None,
    ) -> pd.DataFrame:
        """Read file from local source."""
        table = os.path.join(self._base_path, f"{table}.parquet")  # noqa: PTH118

        if not os.path.exists(table):  # noqa: PTH110
            if throw_on_missing:
                error_msg = f"Table {table} does not exist"
                raise FileNotFoundError(error_msg)

            print(f"Table {table} does not exist")  # noqa T201
            return (
                pd.DataFrame(data=[], columns=columns)
                if columns is not None
                else pd.DataFrame()
            )
        return pd.read_parquet(table, columns=columns)

    def read_settings(
        self,
        file: str,
        throw_on_missing: bool = False,
    ) -> GraphRagConfig | None:
        """Read settings file from local source."""
        cwd = Path(__file__).parent
        root_dir = (cwd / self._base_path).resolve()
        return load_config(root_dir=root_dir)
unified-search-app/app/knowledge_loader/data_sources/typing.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Data sources typing module."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig


class WriteMode(Enum):
    """An enum for the write modes of a datasource."""

    # Overwrite means all the data in the table will be replaced with the new data.
    Overwrite = 1

    # Append means the new data will be appended to the existing data in the table.
    Append = 2


class Datasource(ABC):
    """An interface for a datasource, which is a function that takes a table name and returns a DataFrame or None."""

    def __call__(self, table: str, columns: list[str] | None) -> pd.DataFrame:
        """Call method definition."""
        raise NotImplementedError

    @abstractmethod
    def read(
        self,
        table: str,
        throw_on_missing: bool = False,
        columns: list[str] | None = None,
    ) -> pd.DataFrame:
        """Read method definition."""
        raise NotImplementedError

    @abstractmethod
    def read_settings(self, file: str) -> GraphRagConfig | None:
        """Read settings method definition."""
        raise NotImplementedError

    def write(
        self, table: str, df: pd.DataFrame, mode: WriteMode | None = None
    ) -> None:
        """Write method definition."""
        raise NotImplementedError

    def has_table(self, table: str) -> bool:
        """Check if table exists method definition."""
        raise NotImplementedError


@dataclass
class VectorIndexConfig:
    """VectorIndexConfig class definition."""

    index_name: str
    embeddings_file: str
    content_file: str | None = None


@dataclass
class DatasetConfig:
    """DatasetConfig class definition."""

    key: str
    path: str
    name: str
    description: str
    community_level: int
unified-search-app/app/knowledge_loader/model.py (new file, 110 lines)
@@ -0,0 +1,110 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Model module."""

from dataclasses import dataclass

import pandas as pd
import streamlit as st
from data_config import (
    default_ttl,
)
from knowledge_loader.data_prep import (
    get_communities_data,
    get_community_report_data,
    get_covariate_data,
    get_entity_data,
    get_relationship_data,
    get_text_unit_data,
)
from knowledge_loader.data_sources.typing import Datasource

"""
Contains functions to load graph-indexed data into collections of knowledge model objects.
These collections will be used as inputs for the graphrag-orchestration functions.
"""


@st.cache_data(ttl=default_ttl)
def load_entities(
    dataset: str,
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return a list of Entity objects."""
    return get_entity_data(dataset, _datasource)


@st.cache_data(ttl=default_ttl)
def load_entity_relationships(
    dataset: str,
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return lists of Entity and Relationship objects."""
    return get_relationship_data(dataset, _datasource)


@st.cache_data(ttl=default_ttl)
def load_covariates(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a dictionary of Covariate objects, with the key being the covariate type."""
    return get_covariate_data(dataset, _datasource)


@st.cache_data(ttl=default_ttl)
def load_community_reports(
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return a list of CommunityReport objects."""
    return get_community_report_data(_datasource)


@st.cache_data(ttl=default_ttl)
def load_communities(
    _datasource: Datasource,
) -> pd.DataFrame:
    """Return a list of Communities objects."""
    return get_communities_data(_datasource)


@st.cache_data(ttl=default_ttl)
def load_text_units(dataset: str, _datasource: Datasource) -> pd.DataFrame:
    """Return a list of TextUnit objects."""
    return get_text_unit_data(dataset, _datasource)


@dataclass
class KnowledgeModel:
    """KnowledgeModel class definition."""

    entities: pd.DataFrame
    relationships: pd.DataFrame
    community_reports: pd.DataFrame
    communities: pd.DataFrame
    text_units: pd.DataFrame
    covariates: pd.DataFrame | None = None


def load_model(
    dataset: str,
    datasource: Datasource,
):
    """
    Load all relevant graph-indexed data into collections of knowledge model objects and store the model collections in the session variables.

    This is a one-time data retrieval and preparation per session.
    """
    entities = load_entities(dataset, datasource)
    relationships = load_entity_relationships(dataset, datasource)
    covariates = load_covariates(dataset, datasource)
    community_reports = load_community_reports(datasource)
    communities = load_communities(datasource)
    text_units = load_text_units(dataset, datasource)

    return KnowledgeModel(
        entities=entities,
        relationships=relationships,
        community_reports=community_reports,
        communities=communities,
        text_units=text_units,
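        # covariates (claims) are optional; pass None when the table is empty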
        covariates=(None if covariates.empty else covariates),
    )
unified-search-app/app/rag/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Rag module."""
unified-search-app/app/rag/typing.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Typing module."""

from dataclasses import dataclass
from enum import Enum

import pandas as pd


class SearchType(Enum):
    """SearchType class definition."""

    Basic = "basic"
    Local = "local"
    Global = "global"
    Drift = "drift"


@dataclass
class SearchResult:
    """SearchResult class definition."""

    # create a dataclass to store the search result of each algorithm
    search_type: SearchType
    response: str
    context: dict[str, pd.DataFrame]
unified-search-app/app/state/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""App state module."""
unified-search-app/app/state/query_variable.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Query variable module."""

from typing import Any

import streamlit as st


class QueryVariable:
    """
    Manage reading and writing variables from the URL query string.

    We handle translation between string values and bools, accounting for always-lowercase URLs to avoid case issues.
    Note that all variables are managed via session state to account for widgets that auto-read.
    We just push them up to the query to keep it updated.
    """

    def __init__(self, key: str, default: Any | None):
        """Init method definition."""
        self._key = key
        val = st.query_params[key].lower() if key in st.query_params else default
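        # query params are always strings; coerce "true"/"false" back to bools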
        if val == "true":
            val = True
        elif val == "false":
            val = False
        if key not in st.session_state:
            st.session_state[key] = val

    @property
    def key(self) -> str:
        """Key property definition."""
        return self._key

    @property
    def value(self) -> Any:
        """Value property definition."""
        return st.session_state[self._key]

    @value.setter
    def value(self, value: Any) -> None:
        """Value setter definition."""
        st.session_state[self._key] = value
        st.query_params[self._key] = f"{value}".lower()
unified-search-app/app/state/session_variable.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Session variable module."""

import traceback
from typing import Any

import streamlit as st


class SessionVariable:
    """Define the session variable structure that will be used in the app."""

    def __init__(self, default: Any = "", prefix: str = ""):
        """Create a managed session variable with a default value and a prefix.

        The prefix is used to avoid collisions between variables with the same name.

        To modify the variable, use the value property, for example: `name.value = "Bob"`.
        To get the value, use the variable itself, for example: `name`.

        Use this class to avoid using the st.session_state dictionary directly and be able to
        just use the variables. These variables will share values across files as long as you use
        the same variable name and prefix.
        """
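        # inspect the caller's source line (e.g. `question = SessionVariable()`)
        # to derive the variable name used as the session-state key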
(_, _, _, text) = traceback.extract_stack()[-2]
|
||||
var_name = text[: text.find("=")].strip()
|
||||
|
||||
self._key = "_".join(arg for arg in [prefix, var_name] if arg != "")
|
||||
self._value = default
|
||||
|
||||
if self._key not in st.session_state:
|
||||
st.session_state[self._key] = default
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Key property definition."""
|
||||
return self._key
|
||||
|
||||
@property
|
||||
def value(self) -> Any:
|
||||
"""Value property definition."""
|
||||
return st.session_state[self._key]
|
||||
|
||||
@value.setter
|
||||
def value(self, value: Any) -> None:
|
||||
"""Value setter definition."""
|
||||
st.session_state[self._key] = value
|
||||
|
||||
def __repr__(self) -> Any:
|
||||
"""Repr method definition."""
|
||||
return str(st.session_state[self._key])
|
||||
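A usage sketch (Streamlit runtime assumed). Because the key is inferred from the assignment's source text via `traceback.extract_stack()`, the variable must be created in a single-line `name = SessionVariable(...)` assignment:

```python
from state.session_variable import SessionVariable

question = SessionVariable(default="", prefix="demo")  # key becomes "demo_question"

question.value = "What changed last week?"
print(question.key)  # -> "demo_question"
print(question)      # __repr__ reads the current value from st.session_state
```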
42 unified-search-app/app/state/session_variables.py Normal file
@@ -0,0 +1,42 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Session variables module."""

from data_config import (
    default_suggested_questions,
)
from state.query_variable import QueryVariable
from state.session_variable import SessionVariable


class SessionVariables:
    """Define all the session variables that will be used in the app."""

    def __init__(self):
        """Init method definition."""
        self.dataset = QueryVariable("dataset", "")
        self.datasets = SessionVariable([])
        self.dataset_config = SessionVariable()
        self.datasource = SessionVariable()
        self.graphrag_config = SessionVariable()
        self.question = QueryVariable("question", "")
        self.suggested_questions = SessionVariable(default_suggested_questions)
        self.entities = SessionVariable([])
        self.relationships = SessionVariable([])
        self.covariates = SessionVariable({})
        self.communities = SessionVariable([])
        self.community_reports = SessionVariable([])
        self.text_units = SessionVariable([])
        self.question_in_progress = SessionVariable("")
        self.include_global_search = QueryVariable("include_global_search", True)
        self.include_local_search = QueryVariable("include_local_search", True)
        self.include_drift_search = QueryVariable("include_drift_search", False)
        self.include_basic_rag = QueryVariable("include_basic_rag", False)

        self.selected_report = SessionVariable()
        self.graph_community_level = SessionVariable(0)

        self.selected_question = SessionVariable("")
        self.generated_questions = SessionVariable([])
        self.show_text_input = SessionVariable(True)
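The app's pages instantiate this class once per script run; since every member resolves to a stable session-state key, re-instantiating on a rerun simply reattaches to the same values. A sketch (the page body is an assumption, not part of this module):

```python
import streamlit as st

from state.session_variables import SessionVariables

sv = SessionVariables()  # safe to call on every rerun

if sv.include_global_search.value:
    st.write(f"Global search enabled for dataset: {sv.dataset.value}")
```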
4 unified-search-app/app/ui/__init__.py Normal file
@@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""App UI module."""
56 unified-search-app/app/ui/full_graph.py Normal file
@@ -0,0 +1,56 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Full graph module."""

import altair as alt
import pandas as pd
import streamlit as st
from state.session_variables import SessionVariables


def create_full_graph_ui(sv: SessionVariables):
    """Return graph UI object."""
    entities = sv.entities.value.copy()
    communities = sv.communities.value.copy()

    if not communities.empty and not entities.empty:
        communities_entities = (
            communities.explode("entity_ids")
            .merge(
                entities,
                left_on="entity_ids",
                right_on="id",
                suffixes=("_entities", "_communities"),
            )
            .dropna(subset=["x", "y"])
        )
    else:
        # keep the expected columns so the filter and chart below still work
        # when either frame is empty
        communities_entities = pd.DataFrame(
            columns=[
                "level",
                "community",
                "x",
                "y",
                "degree",
                "id_entities",
                "type",
                "description",
            ]
        )

    level = sv.graph_community_level.value
    communities_entities_filtered = communities_entities[
        communities_entities["level"] == level
    ]

    graph = (
        alt.Chart(communities_entities_filtered)
        .mark_circle()
        .encode(
            x="x",
            y="y",
            color=alt.Color(
                "community",
                scale=alt.Scale(
                    domain=communities_entities_filtered["community"].unique(),
                    scheme="category10",
                ),
            ),
            size=alt.Size("degree", scale=alt.Scale(range=[50, 1000]), legend=None),
            tooltip=["id_entities", "type", "description", "community"],
        )
        .properties(height=1000)
        .configure_axis(disable=True)
    )
    st.altair_chart(graph, use_container_width=True)
    return graph
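The core of the chart above is the explode-then-merge join that pairs each community with the laid-out positions of its member entities. A toy illustration with made-up frames:

```python
import pandas as pd

communities = pd.DataFrame({
    "community": ["c0"],
    "level": [0],
    "entity_ids": [["e1", "e2"]],
})
entities = pd.DataFrame({
    "id": ["e1", "e2"],
    "x": [0.1, 0.9],
    "y": [0.2, 0.8],
})

joined = (
    communities.explode("entity_ids")  # one row per (community, entity) pair
    .merge(entities, left_on="entity_ids", right_on="id")
    .dropna(subset=["x", "y"])         # drop entities without layout coordinates
)
print(joined[["community", "id", "x", "y"]])
```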
23 unified-search-app/app/ui/questions_list.py Normal file
@@ -0,0 +1,23 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Question list module."""

import streamlit as st
from state.session_variables import SessionVariables


def create_questions_list_ui(sv: SessionVariables):
    """Return question list UI component."""
    selection = st.dataframe(
        sv.generated_questions.value,
        use_container_width=True,
        hide_index=True,
        selection_mode="single-row",
        column_config={"value": "question"},
        on_select="rerun",
    )
    rows = selection.selection.rows
    if len(rows) > 0:
        question_index = rows[0]
        sv.selected_question.value = sv.generated_questions.value[question_index]
98 unified-search-app/app/ui/report_details.py Normal file
@@ -0,0 +1,98 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Report details module."""

import json

import pandas as pd
import streamlit as st
from state.session_variables import SessionVariables
from ui.search import (
    display_graph_citations,
    format_response_hyperlinks,
    get_ids_per_key,
)


def create_report_details_ui(sv: SessionVariables):
    """Return report details UI component."""
    if sv.selected_report.value is not None and sv.selected_report.value.empty is False:
        text = ""
        entity_ids = []
        relationship_ids = []
        try:
            report = json.loads(sv.selected_report.value.full_content_json)
            title = report["title"]
            summary = report["summary"]
            rating = report["rating"]
            rating_explanation = report["rating_explanation"]
            findings = report["findings"]
            text += f"#### {title}\n\n{summary}\n\n"
            text += f"**Priority: {rating}**\n\n{rating_explanation}\n\n##### Key Findings\n\n"
            if isinstance(findings, list):
                for finding in findings:
                    # extract data for citations
                    entity_ids.extend(
                        get_ids_per_key(finding["explanation"], "Entities")
                    )
                    relationship_ids.extend(
                        get_ids_per_key(finding["explanation"], "Relationships")
                    )

                    formatted_text = format_response_hyperlinks(
                        finding["explanation"], "graph"
                    )
                    text += f"\n\n**{finding['summary']}**\n\n{formatted_text}"
            elif isinstance(findings, str):
                # extract data for citations from the single findings string
                entity_ids.extend(get_ids_per_key(findings, "Entities"))
                relationship_ids.extend(
                    get_ids_per_key(findings, "Relationships")
                )

                formatted_text = format_response_hyperlinks(findings, "graph")
                text += f"\n\n{formatted_text}"

        except json.JSONDecodeError:
            st.write("Error parsing report.")
            st.write(sv.selected_report.value.full_content_json)
        text_replacement = (
            text.replace("Entity_Relationships", "Relationships")
            .replace("Entity_Claims", "Claims")
            .replace("Entity_Details", "Entities")
        )
        st.markdown(f"{text_replacement}", unsafe_allow_html=True)

        # extract entities
        selected_entities = []
        for _index, row in sv.entities.value.iterrows():
            if str(row["human_readable_id"]) in entity_ids:
                selected_entities.append({
                    "id": str(row["human_readable_id"]),
                    "title": row["title"],
                    "description": row["description"],
                })

        sorted_entities = sorted(selected_entities, key=lambda x: int(x["id"]))

        # extract relationships
        selected_relationships = []
        for _index, row in sv.relationships.value.iterrows():
            if str(row["human_readable_id"]) in relationship_ids:
                selected_relationships.append({
                    "id": str(row["human_readable_id"]),
                    "source": row["source"],
                    "target": row["target"],
                    "description": row["description"],
                })

        sorted_relationships = sorted(
            selected_relationships, key=lambda x: int(x["id"])
        )

        display_graph_citations(
            pd.DataFrame(sorted_entities), pd.DataFrame(sorted_relationships), "graph"
        )
    else:
        st.write("No report selected")
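For orientation, the report JSON that `create_report_details_ui` parses out of `full_content_json` has roughly this shape; the values are illustrative, but the keys are the ones the code reads:

```python
sample_report = {
    "title": "Community 12",
    "summary": "A short overview of the community.",
    "rating": 7.5,
    "rating_explanation": "High impact due to central entities.",
    "findings": [
        {
            "summary": "Key actors",
            "explanation": "X collaborates with Y [Data: Entities (1, 2); Relationships (3)]",
        }
    ],
}
```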
25 unified-search-app/app/ui/report_list.py Normal file
@@ -0,0 +1,25 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Report list module."""

import streamlit as st
from state.session_variables import SessionVariables


def create_report_list_ui(sv: SessionVariables):
    """Return report list UI component."""
    selection = st.dataframe(
        sv.community_reports.value,
        height=1000,
        hide_index=True,
        column_order=["id", "title"],
        selection_mode="single-row",
        on_select="rerun",
    )
    rows = selection.selection.rows
    if len(rows) > 0:
        report_index = rows[0]
        sv.selected_report.value = sv.community_reports.value.iloc[report_index]
    else:
        sv.selected_report.value = None
284 unified-search-app/app/ui/search.py Normal file
@@ -0,0 +1,284 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Search module."""

import json
import re
from dataclasses import dataclass

import pandas as pd
import streamlit as st
from rag.typing import SearchResult, SearchType
from streamlit.delta_generator import DeltaGenerator


def init_search_ui(
    container: DeltaGenerator, search_type: SearchType, title: str, caption: str
):
    """Initialize search UI component."""
    with container:
        st.markdown(title)
        st.caption(caption)

        ui_tag = search_type.value.lower()
        st.session_state[f"{ui_tag}_response_placeholder"] = st.empty()
        st.session_state[f"{ui_tag}_context_placeholder"] = st.empty()
        st.session_state[f"{ui_tag}_container"] = container


@dataclass
class SearchStats:
    """SearchStats class definition."""

    completion_time: float
    llm_calls: int
    prompt_tokens: int


def display_search_result(
    container: DeltaGenerator, result: SearchResult, stats: SearchStats | None = None
):
    """Display search results data into the UI."""
    response_placeholder_attr = (
        result.search_type.value.lower() + "_response_placeholder"
    )

    with container:
        # display response
        response = format_response_hyperlinks(
            result.response, result.search_type.value.lower()
        )

        if stats is not None and stats.completion_time is not None:
            st.markdown(
                f"*{stats.prompt_tokens:,} tokens used, {stats.llm_calls} LLM calls, {int(stats.completion_time)} seconds elapsed.*"
            )
        st.session_state[response_placeholder_attr] = st.markdown(
            f"<div id='{result.search_type.value.lower()}-response'>{response}</div>",
            unsafe_allow_html=True,
        )


def display_citations(
    container: DeltaGenerator | None = None, result: SearchResult | None = None
):
    """Display citations into the UI."""
    if container is not None:
        with container:
            # display context used for generating the response
            if result is not None:
                context_data = result.context
                context_data = dict(sorted(context_data.items()))

                st.markdown("---")
                st.markdown("### Citations")
                for key, value in context_data.items():
                    if len(value) > 0:
                        key_type = key
                        if key == "sources":
                            st.markdown(
                                f"Relevant chunks of source documents **({len(value)})**:"
                            )
                        elif key == "reports":
                            st.markdown(
                                f"Relevant AI-generated network reports **({len(value)})**:"
                            )
                        else:
                            st.markdown(
                                f"Relevant AI-extracted {key} **({len(value)})**:"
                            )
                        st.markdown(
                            render_html_table(
                                value, result.search_type.value.lower(), key_type
                            ),
                            unsafe_allow_html=True,
                        )


def format_response_hyperlinks(str_response: str, search_type: str = ""):
    """Format response to show hyperlinks inside the response UI."""
    results_with_hyperlinks = format_response_hyperlinks_by_key(
        str_response, "Entities", "Entities", search_type
    )
    results_with_hyperlinks = format_response_hyperlinks_by_key(
        results_with_hyperlinks, "Sources", "Sources", search_type
    )
    results_with_hyperlinks = format_response_hyperlinks_by_key(
        results_with_hyperlinks, "Documents", "Sources", search_type
    )
    results_with_hyperlinks = format_response_hyperlinks_by_key(
        results_with_hyperlinks, "Relationships", "Relationships", search_type
    )
    results_with_hyperlinks = format_response_hyperlinks_by_key(
        results_with_hyperlinks, "Reports", "Reports", search_type
    )

    return results_with_hyperlinks  # noqa: RET504


def format_response_hyperlinks_by_key(
    str_response: str, key: str, anchor: str, search_type: str = ""
):
    """Format response to show hyperlinks inside the response UI by key."""
    # matches citation groups such as "(1, 2, 3)" or "(1, 2, +more)"
    pattern = r"\(\d+(?:,\s*\d+)*(?:,\s*\+more)?\)"

    citations_list = re.findall(f"{key} {pattern}", str_response)

    results_with_hyperlinks = str_response
    if len(citations_list) > 0:
        for occurrence in citations_list:
            string_occurrence = str(occurrence)
            numbers_list = string_occurrence[
                string_occurrence.find("(") + 1 : string_occurrence.find(")")
            ].split(",")
            string_occurrence_hyperlinks = string_occurrence
            for number in numbers_list:
                if number.lower().strip() != "+more":
                    string_occurrence_hyperlinks = string_occurrence_hyperlinks.replace(
                        number,
                        f'<a href="#{search_type.lower().strip()}-{anchor.lower().strip()}-{number.strip()}">{number}</a>',
                    )

            results_with_hyperlinks = results_with_hyperlinks.replace(
                occurrence, string_occurrence_hyperlinks
            )

    return results_with_hyperlinks


def format_suggested_questions(questions: str):
    """Format suggested questions to the UI."""
    citations_pattern = r"\[.*?\]"
    substring = re.sub(citations_pattern, "", questions).strip()
    return convert_numbered_list_to_array(substring)


def convert_numbered_list_to_array(numbered_list_str):
    """Convert a numbered-list string into an array of elements."""
    lines = numbered_list_str.strip().split("\n")
    items = []

    for line in lines:
        match = re.match(r"^\d+\.\s*(.*)", line)
        if match:
            item = match.group(1).strip()
            items.append(item)

    return items


def get_ids_per_key(str_response: str, key: str):
    """Extract all ids cited for the given key."""
    pattern = r"\(\d+(?:,\s*\d+)*(?:,\s*\+more)?\)"
    citations_list = re.findall(f"{key} {pattern}", str_response)
    numbers_list = []
    if len(citations_list) > 0:
        for occurrence in citations_list:
            string_occurrence = str(occurrence)
            # accumulate ids across every occurrence rather than keeping only the last one
            numbers_list.extend(
                number.strip()
                for number in string_occurrence[
                    string_occurrence.find("(") + 1 : string_occurrence.find(")")
                ].split(",")
            )

    return numbers_list


SHORT_WORDS = 12
LONG_WORDS = 200


# Function to generate an HTML table with row ids
def render_html_table(df: pd.DataFrame, search_type: str, key: str):
    """Render HTML table into the UI."""
    table_container = """
        max-width: 100%;
        overflow: hidden;
        margin: 0 auto;
    """

    table_style = """
        width: 100%;
        border-collapse: collapse;
        table-layout: fixed;
    """

    th_style = """
        word-wrap: break-word;
        white-space: normal;
    """

    td_style = """
        border: 1px solid #efefef;
        word-wrap: break-word;
        white-space: normal;
    """

    table_html = f'<div style="{table_container}">'
    table_html += f'<table style="{table_style}">'

    table_html += "<thead><tr>"
    for col in pd.DataFrame(df).columns:
        table_html += f'<th style="{th_style}">{col}</th>'
    table_html += "</tr></thead>"

    table_html += "<tbody>"
    for index, row in pd.DataFrame(df).iterrows():
        html_id = (
            f"{search_type.lower().strip()}-{key.lower().strip()}-{row.id.strip()}"
            if "id" in row
            else f"row-{index}"
        )
        table_html += f'<tr id="{html_id}">'
        for value in row:
            if isinstance(value, str):
                if value[0:1] == "{":
                    value_casted = json.loads(value)
                    value = value_casted["summary"]
                value_array = str(value).split(" ")
                td_value = (
                    " ".join(value_array[:SHORT_WORDS]) + "..."
                    if len(value_array) >= SHORT_WORDS
                    else value
                )
                title_value = (
                    " ".join(value_array[:LONG_WORDS]) + "..."
                    if len(value_array) >= LONG_WORDS
                    else value
                )
                # escape quotes for the HTML title attribute and flatten newlines
                title_value = (
                    title_value.replace('"', "&quot;")
                    .replace("'", "&apos;")
                    .replace("\n", " ")
                    .replace("\n\n", " ")
                    .replace("\r\n", " ")
                )
                table_html += (
                    f'<td style="{td_style}" title="{title_value}">{td_value}</td>'
                )
            else:
                table_html += f'<td style="{td_style}" title="{value}">{value}</td>'
        table_html += "</tr>"
    table_html += "</tbody></table></div>"

    return table_html


def display_graph_citations(
    entities: pd.DataFrame, relationships: pd.DataFrame, citation_type: str
):
    """Display graph citations into the UI."""
    st.markdown("---")
    st.markdown("### Citations")

    st.markdown(f"Relevant AI-extracted entities **({len(entities)})**:")
    st.markdown(
        render_html_table(entities, citation_type, "entities"),
        unsafe_allow_html=True,
    )

    st.markdown(f"Relevant AI-extracted relationships **({len(relationships)})**:")
    st.markdown(
        render_html_table(relationships, citation_type, "relationships"),
        unsafe_allow_html=True,
    )
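To see the citation helpers in action, here is a small sketch; the response string is invented and assumes the `Key (n, m)` citation format the regexes target:

```python
from ui.search import format_response_hyperlinks_by_key, get_ids_per_key

response = "The merger was approved [Data: Entities (12, 34); Reports (5, +more)]"

print(get_ids_per_key(response, "Entities"))  # -> ["12", "34"]

linked = format_response_hyperlinks_by_key(response, "Entities", "Entities", "local")
# each cited id becomes an in-page anchor such as <a href="#local-entities-12">12</a>
print(linked)
```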
97 unified-search-app/app/ui/sidebar.py Normal file
@@ -0,0 +1,97 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Sidebar module."""

import streamlit as st
from app_logic import dataset_name, load_dataset
from state.session_variables import SessionVariables


def reset_app():
    """Reset app to its original state."""
    st.cache_data.clear()
    st.session_state.clear()
    st.rerun()


def update_dataset(sv: SessionVariables):
    """Update dataset from the dropdown."""
    value = st.session_state[sv.dataset.key]
    st.cache_data.clear()
    st.session_state.response_lengths = []
    load_dataset(value, sv)


def update_basic_rag(sv: SessionVariables):
    """Update basic RAG state."""
    sv.include_basic_rag.value = st.session_state[sv.include_basic_rag.key]


def update_drift_search(sv: SessionVariables):
    """Update drift search state."""
    sv.include_drift_search.value = st.session_state[sv.include_drift_search.key]


def update_local_search(sv: SessionVariables):
    """Update local search state."""
    sv.include_local_search.value = st.session_state[sv.include_local_search.key]


def update_global_search(sv: SessionVariables):
    """Update global search state."""
    sv.include_global_search.value = st.session_state[sv.include_global_search.key]


def create_side_bar(sv: SessionVariables):
    """Create a sidebar panel."""
    with st.sidebar:
        st.subheader("Options")

        options = [d.key for d in sv.datasets.value]

        def lookup_label(key: str):
            return dataset_name(key, sv)

        st.selectbox(
            "Dataset",
            key=sv.dataset.key,
            on_change=update_dataset,
            kwargs={"sv": sv},
            options=options,
            format_func=lookup_label,
        )
        st.number_input(
            "Number of suggested questions",
            key=sv.suggested_questions.key,
            min_value=1,
            max_value=100,
            step=1,
        )
        st.subheader("Search options:")
        st.toggle(
            "Include basic RAG",
            key=sv.include_basic_rag.key,
            on_change=update_basic_rag,
            kwargs={"sv": sv},
        )
        st.toggle(
            "Include local search",
            key=sv.include_local_search.key,
            on_change=update_local_search,
            kwargs={"sv": sv},
        )
        st.toggle(
            "Include global search",
            key=sv.include_global_search.key,
            on_change=update_global_search,
            kwargs={"sv": sv},
        )
        st.toggle(
            "Include drift search",
            key=sv.include_drift_search.key,
            on_change=update_drift_search,
            kwargs={"sv": sv},
        )
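The sidebar relies on a standard Streamlit pattern: each widget writes to `st.session_state` under its `key` before `on_change` fires, so the callbacks above only need to mirror that value into the app's own state. A minimal standalone sketch of the same pattern (names are hypothetical):

```python
import streamlit as st

def on_toggle_change():
    # the widget has already updated st.session_state["demo_toggle"]
    # by the time this callback runs
    st.session_state["demo_mirror"] = st.session_state["demo_toggle"]

st.toggle("Example toggle", key="demo_toggle", on_change=on_toggle_change)
```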
BIN unified-search-app/images/image-1.png Normal file (binary file not shown; size after: 378 KiB)
BIN unified-search-app/images/image-2.png Normal file (binary file not shown; size after: 451 KiB)
BIN unified-search-app/images/image-3.png Normal file (binary file not shown; size after: 486 KiB)
BIN unified-search-app/images/image-4.png Normal file (binary file not shown; size after: 550 KiB)
4662 unified-search-app/poetry.lock Normal file (generated; diff suppressed because it is too large)
37 unified-search-app/pyproject.toml Normal file
@@ -0,0 +1,37 @@
[tool.poetry]
name = "unified-copilot"
version = "1.0.0"
description = ""
authors = ["GraphRAG team"]
readme = "README.md"
package-mode = false

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
streamlit = "1.43.0"
azure-search-documents = "^11.4.0"
azure-storage-blob = "^12.20.0"
azure-identity = "^1.16.0"
graphrag = "2.0.0"
altair = "^5.3.0"
streamlit-agraph = "^0.0.45"
st-tabs = "^0.1.1"
spacy = ">=3.8.4,<4.0.0"

[tool.poetry.group.dev.dependencies]
poethepoet = "^0.26.1"
ipykernel = "^6.29.4"
pyright = "^1.1.349"
ruff = "^0.4.7"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poe.tasks]
start = "streamlit run app/home_page.py"
start_prod = "streamlit run app/home_page.py --server.port=8501 --server.address=0.0.0.0"

[tool.pyright]
include = ["app"]
exclude = ["**/node_modules", "**/__pycache__"]