mirror of
https://github.com/microsoft/graphrag.git
synced 2026-01-14 00:57:23 +08:00
Implement dynamic community selection for global search (#1396)
* update gitignore * add dynamic community sleection to updated main branch * update SearchResult to record output_tokens. * update search result * dynamic search working * format * add llm_calls_categories and prompt_tokens and output_tokens cate * update * formatting * log drift search output and prompt tokens separately * update global_search.ipynb. update operate dulce dataset and add create_final_communities. update dynamic community selection init * add .ipynb back to cspell.config.yaml * format * add notebook example on dynamic search * rearrange * update gitignore * format code * code format * code format * fix default variable --------- Co-authored-by: Bryan Li <bryanlimy@gmail.com>
This commit is contained in:
parent
ba50caab4d
commit
e53422366d
6
.gitignore
vendored
6
.gitignore
vendored
@ -50,3 +50,9 @@ site/
|
||||
docsite/
|
||||
.yarn/
|
||||
.pnp*
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# Jupyter notebook
|
||||
.ipynb_checkpoints/
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Implement dynamic community selection to global search"
|
||||
}
|
||||
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -21,7 +21,11 @@
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.query.indexer_adapters import read_indexer_entities, read_indexer_reports\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_communities,\n",
|
||||
" read_indexer_entities,\n",
|
||||
" read_indexer_reports,\n",
|
||||
")\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"from graphrag.query.structured_search.global_search.community_context import (\n",
|
||||
@ -62,7 +66,7 @@
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.get_encoding(\"cl100k_base\")"
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -72,17 +76,19 @@
|
||||
"### Load community reports as context for global search\n",
|
||||
"\n",
|
||||
"- Load all community reports in the `create_final_community_reports` table from the ire-indexing engine, to be used as context data for global search.\n",
|
||||
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the `rank` attribute in the community reports table for context ranking)"
|
||||
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the rank attribute in the community reports table for context ranking)\n",
|
||||
"- Load all communities in the `create_final_communites` table from the ire-indexing engine, to be used to reconstruct the community graph hierarchy for dynamic community selection."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# parquet files generated from indexing pipeline\n",
|
||||
"INPUT_DIR = \"./inputs/operation dulce\"\n",
|
||||
"COMMUNITY_TABLE = \"create_final_communities\"\n",
|
||||
"COMMUNITY_REPORT_TABLE = \"create_final_community_reports\"\n",
|
||||
"ENTITY_TABLE = \"create_final_nodes\"\n",
|
||||
"ENTITY_EMBEDDING_TABLE = \"create_final_entities\"\n",
|
||||
@ -94,20 +100,191 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total report count: 20\n",
|
||||
"Report count after filtering by community level 2: 17\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>community</th>\n",
|
||||
" <th>full_content</th>\n",
|
||||
" <th>level</th>\n",
|
||||
" <th>rank</th>\n",
|
||||
" <th>title</th>\n",
|
||||
" <th>rank_explanation</th>\n",
|
||||
" <th>summary</th>\n",
|
||||
" <th>findings</th>\n",
|
||||
" <th>full_content_json</th>\n",
|
||||
" <th>id</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The Paranormal Military Squad, stationed at Du...</td>\n",
|
||||
" <td>[{'explanation': 'Jordan is a central figure i...</td>\n",
|
||||
" <td>{\\n \"title\": \"Paranormal Military Squad at ...</td>\n",
|
||||
" <td>1ba2d200-dd26-4693-affe-a5539d0a0e0d</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The community centers around Dulce, a secretiv...</td>\n",
|
||||
" <td>[{'explanation': 'Dulce is described as a top-...</td>\n",
|
||||
" <td>{\\n \"title\": \"Dulce and Paranormal Military...</td>\n",
|
||||
" <td>a8a530b0-ae6b-44ea-b11c-9f70d138298d</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>12</td>\n",
|
||||
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
|
||||
" <td>The impact severity rating is relatively high ...</td>\n",
|
||||
" <td>The community centers around the Paranormal Mi...</td>\n",
|
||||
" <td>[{'explanation': 'Taylor is a central figure w...</td>\n",
|
||||
" <td>{\\n \"title\": \"Paranormal Military Squad and...</td>\n",
|
||||
" <td>0478975b-c805-4cc1-b746-82f3e689e2f3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
|
||||
" <td>The impact severity rating is relatively high ...</td>\n",
|
||||
" <td>This report explores the intricate dynamics of...</td>\n",
|
||||
" <td>[{'explanation': 'Cruz is a central figure in ...</td>\n",
|
||||
" <td>{\\n \"title\": \"Mission Dynamics and Leadersh...</td>\n",
|
||||
" <td>b56f6e68-3951-4f07-8760-63700944a375</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>14</td>\n",
|
||||
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The community centers around the Dulce Base, a...</td>\n",
|
||||
" <td>[{'explanation': 'Sam Rivera, a member of the ...</td>\n",
|
||||
" <td>{\\n \"title\": \"Dulce Base and Paranormal Mil...</td>\n",
|
||||
" <td>736e7006-d050-4abb-a122-00febf3f540f</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" community full_content level rank \\\n",
|
||||
"0 10 # Paranormal Military Squad at Dulce Base: Dec... 1 8.5 \n",
|
||||
"1 11 # Dulce and Paranormal Military Squad Operatio... 1 8.5 \n",
|
||||
"2 12 # Paranormal Military Squad and Dulce Base Ope... 1 7.5 \n",
|
||||
"3 13 # Mission Dynamics and Leadership: Cruz and Wa... 1 7.5 \n",
|
||||
"4 14 # Dulce Base and Paranormal Military Squad: Br... 1 8.5 \n",
|
||||
"\n",
|
||||
" title \\\n",
|
||||
"0 Paranormal Military Squad at Dulce Base: Decod... \n",
|
||||
"1 Dulce and Paranormal Military Squad Operations \n",
|
||||
"2 Paranormal Military Squad and Dulce Base Opera... \n",
|
||||
"3 Mission Dynamics and Leadership: Cruz and Wash... \n",
|
||||
"4 Dulce Base and Paranormal Military Squad: Brid... \n",
|
||||
"\n",
|
||||
" rank_explanation \\\n",
|
||||
"0 The impact severity rating is high due to the ... \n",
|
||||
"1 The impact severity rating is high due to the ... \n",
|
||||
"2 The impact severity rating is relatively high ... \n",
|
||||
"3 The impact severity rating is relatively high ... \n",
|
||||
"4 The impact severity rating is high due to the ... \n",
|
||||
"\n",
|
||||
" summary \\\n",
|
||||
"0 The Paranormal Military Squad, stationed at Du... \n",
|
||||
"1 The community centers around Dulce, a secretiv... \n",
|
||||
"2 The community centers around the Paranormal Mi... \n",
|
||||
"3 This report explores the intricate dynamics of... \n",
|
||||
"4 The community centers around the Dulce Base, a... \n",
|
||||
"\n",
|
||||
" findings \\\n",
|
||||
"0 [{'explanation': 'Jordan is a central figure i... \n",
|
||||
"1 [{'explanation': 'Dulce is described as a top-... \n",
|
||||
"2 [{'explanation': 'Taylor is a central figure w... \n",
|
||||
"3 [{'explanation': 'Cruz is a central figure in ... \n",
|
||||
"4 [{'explanation': 'Sam Rivera, a member of the ... \n",
|
||||
"\n",
|
||||
" full_content_json \\\n",
|
||||
"0 {\\n \"title\": \"Paranormal Military Squad at ... \n",
|
||||
"1 {\\n \"title\": \"Dulce and Paranormal Military... \n",
|
||||
"2 {\\n \"title\": \"Paranormal Military Squad and... \n",
|
||||
"3 {\\n \"title\": \"Mission Dynamics and Leadersh... \n",
|
||||
"4 {\\n \"title\": \"Dulce Base and Paranormal Mil... \n",
|
||||
"\n",
|
||||
" id \n",
|
||||
"0 1ba2d200-dd26-4693-affe-a5539d0a0e0d \n",
|
||||
"1 a8a530b0-ae6b-44ea-b11c-9f70d138298d \n",
|
||||
"2 0478975b-c805-4cc1-b746-82f3e689e2f3 \n",
|
||||
"3 b56f6e68-3951-4f07-8760-63700944a375 \n",
|
||||
"4 736e7006-d050-4abb-a122-00febf3f540f "
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
|
||||
"entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
|
||||
"report_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet\")\n",
|
||||
"entity_embedding_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet\")\n",
|
||||
"\n",
|
||||
"communities = read_indexer_communities(community_df, entity_df, report_df)\n",
|
||||
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
|
||||
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
|
||||
"\n",
|
||||
"print(f\"Total report count: {len(report_df)}\")\n",
|
||||
"print(\n",
|
||||
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"report_df.head()"
|
||||
]
|
||||
},
|
||||
@ -120,12 +297,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"context_builder = GlobalCommunityContext(\n",
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
")"
|
||||
@ -140,7 +318,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -171,7 +349,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -192,12 +370,36 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"### Overview of Cosmic Vocalization\n",
|
||||
"\n",
|
||||
"Cosmic Vocalization is a phenomenon that has garnered significant attention from various individuals and groups. It is perceived as a cosmic event with potential implications for security and interstellar communication. The Paranormal Military Squad is actively engaged with Cosmic Vocalization, indicating its strategic importance in security measures [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"### Key Perspectives and Concerns\n",
|
||||
"\n",
|
||||
"1. **Strategic Engagement**: The Paranormal Military Squad's involvement suggests that Cosmic Vocalization is not only a subject of interest but also a matter of strategic importance. This engagement highlights the potential security implications of these cosmic phenomena [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"2. **Community Interest**: Within the community, Cosmic Vocalization is a focal point of interest. Alex Mercer, for instance, perceives it as part of an interstellar duet, which suggests a responsive and perhaps communicative approach to these cosmic events [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"3. **Potential Threats**: Concerns have been raised by individuals like Taylor Cruz, who fears that Cosmic Vocalization might be a homing tune. This perspective adds a layer of urgency and suggests that there may be potential threats associated with these cosmic sounds [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"### Metaphorical Interpretation\n",
|
||||
"\n",
|
||||
"The Universe is metaphorically treated as a concert hall by the Paranormal Military Squad, which suggests a broader perspective on how cosmic events are interpreted and responded to by human entities. This metaphorical view may influence how strategies and responses are formulated in relation to Cosmic Vocalization [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"In summary, Cosmic Vocalization is a complex phenomenon involving strategic, communicative, and potentially threatening elements. The involvement of the Paranormal Military Squad and the concerns raised by community members underscore its significance and the need for careful consideration of its implications.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = await search_engine.asearch(\n",
|
||||
" \"What is the major conflict in this story and who are the protagonist and antagonist?\"\n",
|
||||
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(result.response)"
|
||||
@ -205,9 +407,223 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>id</th>\n",
|
||||
" <th>title</th>\n",
|
||||
" <th>occurrence weight</th>\n",
|
||||
" <th>content</th>\n",
|
||||
" <th>rank</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>15</td>\n",
|
||||
" <td>Dulce Base and the Paranormal Military Squad: ...</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td># Dulce Base and the Paranormal Military Squad...</td>\n",
|
||||
" <td>9.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
|
||||
" <td>0.30</td>\n",
|
||||
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
|
||||
" <td>0.30</td>\n",
|
||||
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>7</td>\n",
|
||||
" <td>Operation: Dulce and the Paranormal Military S...</td>\n",
|
||||
" <td>0.20</td>\n",
|
||||
" <td># Operation: Dulce and the Paranormal Military...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>8</td>\n",
|
||||
" <td>Dr. Jordan Hayes and the Paranormal Military S...</td>\n",
|
||||
" <td>0.18</td>\n",
|
||||
" <td># Dr. Jordan Hayes and the Paranormal Military...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>Earth's Interstellar Communication Initiative</td>\n",
|
||||
" <td>0.16</td>\n",
|
||||
" <td># Earth's Interstellar Communication Initiativ...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>12</td>\n",
|
||||
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
|
||||
" <td>0.16</td>\n",
|
||||
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
|
||||
" <td>0.16</td>\n",
|
||||
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>14</td>\n",
|
||||
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
|
||||
" <td>0.12</td>\n",
|
||||
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>16</td>\n",
|
||||
" <td>Dulce Military Base and Alien Intelligence Com...</td>\n",
|
||||
" <td>0.08</td>\n",
|
||||
" <td># Dulce Military Base and Alien Intelligence C...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>Paranormal Military Squad Team and Dulce Base'...</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td># Paranormal Military Squad Team and Dulce Bas...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>Alien Script and First Contact Operations</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Alien Script and First Contact Operations\\n\\...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>Dulce Facility and Control Room of Dulce: Extr...</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Dulce Facility and Control Room of Dulce: Ex...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>17</td>\n",
|
||||
" <td>Dulce Team and Underground Command Center: Int...</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Dulce Team and Underground Command Center: I...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>19</td>\n",
|
||||
" <td>Central Terminal and Viewing Monitors at Dulce...</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Central Terminal and Viewing Monitors at Dul...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>Cosmic Vocalization and Universe Interactions</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Cosmic Vocalization and Universe Interaction...</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>16</th>\n",
|
||||
" <td>9</td>\n",
|
||||
" <td>Dulce Base Exploration by TEAM and MAINFRAME ROOM</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Dulce Base Exploration by TEAM and MAINFRAME...</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" id title occurrence weight \\\n",
|
||||
"0 15 Dulce Base and the Paranormal Military Squad: ... 1.00 \n",
|
||||
"1 11 Dulce and Paranormal Military Squad Operations 0.30 \n",
|
||||
"2 10 Paranormal Military Squad at Dulce Base: Decod... 0.30 \n",
|
||||
"3 7 Operation: Dulce and the Paranormal Military S... 0.20 \n",
|
||||
"4 8 Dr. Jordan Hayes and the Paranormal Military S... 0.18 \n",
|
||||
"5 1 Earth's Interstellar Communication Initiative 0.16 \n",
|
||||
"6 12 Paranormal Military Squad and Dulce Base Opera... 0.16 \n",
|
||||
"7 13 Mission Dynamics and Leadership: Cruz and Wash... 0.16 \n",
|
||||
"8 14 Dulce Base and Paranormal Military Squad: Brid... 0.12 \n",
|
||||
"9 16 Dulce Military Base and Alien Intelligence Com... 0.08 \n",
|
||||
"10 18 Paranormal Military Squad Team and Dulce Base'... 0.04 \n",
|
||||
"11 5 Alien Script and First Contact Operations 0.02 \n",
|
||||
"12 4 Dulce Facility and Control Room of Dulce: Extr... 0.02 \n",
|
||||
"13 17 Dulce Team and Underground Command Center: Int... 0.02 \n",
|
||||
"14 19 Central Terminal and Viewing Monitors at Dulce... 0.02 \n",
|
||||
"15 6 Cosmic Vocalization and Universe Interactions 0.02 \n",
|
||||
"16 9 Dulce Base Exploration by TEAM and MAINFRAME ROOM 0.02 \n",
|
||||
"\n",
|
||||
" content rank \n",
|
||||
"0 # Dulce Base and the Paranormal Military Squad... 9.5 \n",
|
||||
"1 # Dulce and Paranormal Military Squad Operatio... 8.5 \n",
|
||||
"2 # Paranormal Military Squad at Dulce Base: Dec... 8.5 \n",
|
||||
"3 # Operation: Dulce and the Paranormal Military... 8.5 \n",
|
||||
"4 # Dr. Jordan Hayes and the Paranormal Military... 8.5 \n",
|
||||
"5 # Earth's Interstellar Communication Initiativ... 8.5 \n",
|
||||
"6 # Paranormal Military Squad and Dulce Base Ope... 7.5 \n",
|
||||
"7 # Mission Dynamics and Leadership: Cruz and Wa... 7.5 \n",
|
||||
"8 # Dulce Base and Paranormal Military Squad: Br... 8.5 \n",
|
||||
"9 # Dulce Military Base and Alien Intelligence C... 8.5 \n",
|
||||
"10 # Paranormal Military Squad Team and Dulce Bas... 8.5 \n",
|
||||
"11 # Alien Script and First Contact Operations\\n\\... 8.5 \n",
|
||||
"12 # Dulce Facility and Control Room of Dulce: Ex... 8.5 \n",
|
||||
"13 # Dulce Team and Underground Command Center: I... 8.5 \n",
|
||||
"14 # Central Terminal and Viewing Monitors at Dul... 8.5 \n",
|
||||
"15 # Cosmic Vocalization and Universe Interaction... 7.5 \n",
|
||||
"16 # Dulce Base Exploration by TEAM and MAINFRAME... 7.5 "
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# inspect the data used to build the context for the LLM responses\n",
|
||||
"result.context_data[\"reports\"]"
|
||||
@ -215,18 +631,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"LLM calls: 2. Prompt tokens: 11292. Output tokens: 606.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# inspect number of LLM calls and tokens\n",
|
||||
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"
|
||||
"print(\n",
|
||||
" f\"LLM calls: {result.llm_calls}. Prompt tokens: {result.prompt_tokens}. Output tokens: {result.output_tokens}.\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
@ -240,7 +666,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -0,0 +1,615 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) 2024 Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"from graphrag.query.indexer_adapters import (\n",
|
||||
" read_indexer_communities,\n",
|
||||
" read_indexer_entities,\n",
|
||||
" read_indexer_reports,\n",
|
||||
")\n",
|
||||
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
|
||||
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
|
||||
"from graphrag.query.structured_search.global_search.community_context import (\n",
|
||||
" GlobalCommunityContext,\n",
|
||||
")\n",
|
||||
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Global Search example\n",
|
||||
"\n",
|
||||
"Global search method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole (e.g. What are the most significant values of the herbs mentioned in this notebook?)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### LLM setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
|
||||
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(\n",
|
||||
" api_key=api_key,\n",
|
||||
" model=llm_model,\n",
|
||||
" api_type=OpenaiApiType.OpenAI, # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"token_encoder = tiktoken.encoding_for_model(llm_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load community reports as context for global search\n",
|
||||
"\n",
|
||||
"- Load all community reports in the `create_final_community_reports` table from the ire-indexing engine, to be used as context data for global search.\n",
|
||||
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the rank attribute in the community reports table for context ranking)\n",
|
||||
"- Load all communities in the `create_final_communites` table from the ire-indexing engine, to be used to reconstruct the community graph hierarchy for dynamic community selection."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# parquet files generated from indexing pipeline\n",
|
||||
"INPUT_DIR = \"./inputs/operation dulce\"\n",
|
||||
"COMMUNITY_TABLE = \"create_final_communities\"\n",
|
||||
"COMMUNITY_REPORT_TABLE = \"create_final_community_reports\"\n",
|
||||
"ENTITY_TABLE = \"create_final_nodes\"\n",
|
||||
"ENTITY_EMBEDDING_TABLE = \"create_final_entities\"\n",
|
||||
"\n",
|
||||
"# we don't fix a specific community level but instead use an agent to dynamicially\n",
|
||||
"# search through all the community reports to check if they are relevant.\n",
|
||||
"COMMUNITY_LEVEL = None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Total report count: 20\n",
|
||||
"Report count after filtering by community level None: 20\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>community</th>\n",
|
||||
" <th>full_content</th>\n",
|
||||
" <th>level</th>\n",
|
||||
" <th>rank</th>\n",
|
||||
" <th>title</th>\n",
|
||||
" <th>rank_explanation</th>\n",
|
||||
" <th>summary</th>\n",
|
||||
" <th>findings</th>\n",
|
||||
" <th>full_content_json</th>\n",
|
||||
" <th>id</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The Paranormal Military Squad, stationed at Du...</td>\n",
|
||||
" <td>[{'explanation': 'Jordan is a central figure i...</td>\n",
|
||||
" <td>{\\n \"title\": \"Paranormal Military Squad at ...</td>\n",
|
||||
" <td>1ba2d200-dd26-4693-affe-a5539d0a0e0d</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>11</td>\n",
|
||||
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The community centers around Dulce, a secretiv...</td>\n",
|
||||
" <td>[{'explanation': 'Dulce is described as a top-...</td>\n",
|
||||
" <td>{\\n \"title\": \"Dulce and Paranormal Military...</td>\n",
|
||||
" <td>a8a530b0-ae6b-44ea-b11c-9f70d138298d</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>12</td>\n",
|
||||
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
|
||||
" <td>The impact severity rating is relatively high ...</td>\n",
|
||||
" <td>The community centers around the Paranormal Mi...</td>\n",
|
||||
" <td>[{'explanation': 'Taylor is a central figure w...</td>\n",
|
||||
" <td>{\\n \"title\": \"Paranormal Military Squad and...</td>\n",
|
||||
" <td>0478975b-c805-4cc1-b746-82f3e689e2f3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>13</td>\n",
|
||||
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
|
||||
" <td>The impact severity rating is relatively high ...</td>\n",
|
||||
" <td>This report explores the intricate dynamics of...</td>\n",
|
||||
" <td>[{'explanation': 'Cruz is a central figure in ...</td>\n",
|
||||
" <td>{\\n \"title\": \"Mission Dynamics and Leadersh...</td>\n",
|
||||
" <td>b56f6e68-3951-4f07-8760-63700944a375</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>14</td>\n",
|
||||
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
|
||||
" <td>The impact severity rating is high due to the ...</td>\n",
|
||||
" <td>The community centers around the Dulce Base, a...</td>\n",
|
||||
" <td>[{'explanation': 'Sam Rivera, a member of the ...</td>\n",
|
||||
" <td>{\\n \"title\": \"Dulce Base and Paranormal Mil...</td>\n",
|
||||
" <td>736e7006-d050-4abb-a122-00febf3f540f</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" community full_content level rank \\\n",
|
||||
"0 10 # Paranormal Military Squad at Dulce Base: Dec... 1 8.5 \n",
|
||||
"1 11 # Dulce and Paranormal Military Squad Operatio... 1 8.5 \n",
|
||||
"2 12 # Paranormal Military Squad and Dulce Base Ope... 1 7.5 \n",
|
||||
"3 13 # Mission Dynamics and Leadership: Cruz and Wa... 1 7.5 \n",
|
||||
"4 14 # Dulce Base and Paranormal Military Squad: Br... 1 8.5 \n",
|
||||
"\n",
|
||||
" title \\\n",
|
||||
"0 Paranormal Military Squad at Dulce Base: Decod... \n",
|
||||
"1 Dulce and Paranormal Military Squad Operations \n",
|
||||
"2 Paranormal Military Squad and Dulce Base Opera... \n",
|
||||
"3 Mission Dynamics and Leadership: Cruz and Wash... \n",
|
||||
"4 Dulce Base and Paranormal Military Squad: Brid... \n",
|
||||
"\n",
|
||||
" rank_explanation \\\n",
|
||||
"0 The impact severity rating is high due to the ... \n",
|
||||
"1 The impact severity rating is high due to the ... \n",
|
||||
"2 The impact severity rating is relatively high ... \n",
|
||||
"3 The impact severity rating is relatively high ... \n",
|
||||
"4 The impact severity rating is high due to the ... \n",
|
||||
"\n",
|
||||
" summary \\\n",
|
||||
"0 The Paranormal Military Squad, stationed at Du... \n",
|
||||
"1 The community centers around Dulce, a secretiv... \n",
|
||||
"2 The community centers around the Paranormal Mi... \n",
|
||||
"3 This report explores the intricate dynamics of... \n",
|
||||
"4 The community centers around the Dulce Base, a... \n",
|
||||
"\n",
|
||||
" findings \\\n",
|
||||
"0 [{'explanation': 'Jordan is a central figure i... \n",
|
||||
"1 [{'explanation': 'Dulce is described as a top-... \n",
|
||||
"2 [{'explanation': 'Taylor is a central figure w... \n",
|
||||
"3 [{'explanation': 'Cruz is a central figure in ... \n",
|
||||
"4 [{'explanation': 'Sam Rivera, a member of the ... \n",
|
||||
"\n",
|
||||
" full_content_json \\\n",
|
||||
"0 {\\n \"title\": \"Paranormal Military Squad at ... \n",
|
||||
"1 {\\n \"title\": \"Dulce and Paranormal Military... \n",
|
||||
"2 {\\n \"title\": \"Paranormal Military Squad and... \n",
|
||||
"3 {\\n \"title\": \"Mission Dynamics and Leadersh... \n",
|
||||
"4 {\\n \"title\": \"Dulce Base and Paranormal Mil... \n",
|
||||
"\n",
|
||||
" id \n",
|
||||
"0 1ba2d200-dd26-4693-affe-a5539d0a0e0d \n",
|
||||
"1 a8a530b0-ae6b-44ea-b11c-9f70d138298d \n",
|
||||
"2 0478975b-c805-4cc1-b746-82f3e689e2f3 \n",
|
||||
"3 b56f6e68-3951-4f07-8760-63700944a375 \n",
|
||||
"4 736e7006-d050-4abb-a122-00febf3f540f "
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
|
||||
"entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
|
||||
"report_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet\")\n",
|
||||
"entity_embedding_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet\")\n",
|
||||
"\n",
|
||||
"communities = read_indexer_communities(community_df, entity_df, report_df)\n",
|
||||
"reports = read_indexer_reports(\n",
|
||||
" report_df,\n",
|
||||
" entity_df,\n",
|
||||
" community_level=COMMUNITY_LEVEL,\n",
|
||||
" dynamic_community_selection=True,\n",
|
||||
")\n",
|
||||
"entities = read_indexer_entities(\n",
|
||||
" entity_df, entity_embedding_df, community_level=COMMUNITY_LEVEL\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"Total report count: {len(report_df)}\")\n",
|
||||
"print(\n",
|
||||
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"report_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Build global context with dynamic community selection\n",
|
||||
"\n",
|
||||
"The goal of dynamic community selection reduce the number of community reports that need to be processed in the map-reduce operation. To that end, we take advantage of the hierachical structure of the indexed dataset. We first ask the LLM to rate how relevant each level 0 community is with respect to the user query, we then traverse down the child node(s) if the current community report is deemed relevant.\n",
|
||||
"\n",
|
||||
"You can still set a `COMMUNITY_LEVEL` to filter out lower level community reports and apply dynamic community selection on the filtered reports.\n",
|
||||
"\n",
|
||||
"Note that the dataset is quite small, with only consist of 20 communities from 2 levels (level 0 and 1). Dynamic community selection is more effective when there are large amount of content to be filtered out."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mini_llm = ChatOpenAI(\n",
|
||||
" api_key=api_key,\n",
|
||||
" model=\"gpt-4o-mini\",\n",
|
||||
" api_type=OpenaiApiType.OpenAI, # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI\n",
|
||||
" max_retries=20,\n",
|
||||
")\n",
|
||||
"mini_token_encoder = tiktoken.encoding_for_model(mini_llm.model)\n",
|
||||
"\n",
|
||||
"context_builder = GlobalCommunityContext(\n",
|
||||
" community_reports=reports,\n",
|
||||
" communities=communities,\n",
|
||||
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" dynamic_community_selection=True,\n",
|
||||
" dynamic_community_selection_kwargs={\n",
|
||||
" \"llm\": mini_llm,\n",
|
||||
" \"token_encoder\": mini_token_encoder,\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Perform global search with dynamic community selection"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"context_builder_params = {\n",
|
||||
" \"use_community_summary\": False, # False means using full community reports. True means using community short summaries.\n",
|
||||
" \"shuffle_data\": True,\n",
|
||||
" \"include_community_rank\": True,\n",
|
||||
" \"min_community_rank\": 0,\n",
|
||||
" \"community_rank_name\": \"rank\",\n",
|
||||
" \"include_community_weight\": True,\n",
|
||||
" \"community_weight_name\": \"occurrence weight\",\n",
|
||||
" \"normalize_community_weight\": True,\n",
|
||||
" \"max_tokens\": 12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" \"context_name\": \"Reports\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"map_llm_params = {\n",
|
||||
" \"max_tokens\": 1000,\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
" \"response_format\": {\"type\": \"json_object\"},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"reduce_llm_params = {\n",
|
||||
" \"max_tokens\": 2000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000-1500)\n",
|
||||
" \"temperature\": 0.0,\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"search_engine = GlobalSearch(\n",
|
||||
" llm=llm,\n",
|
||||
" context_builder=context_builder,\n",
|
||||
" token_encoder=token_encoder,\n",
|
||||
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
|
||||
" map_llm_params=map_llm_params,\n",
|
||||
" reduce_llm_params=reduce_llm_params,\n",
|
||||
" allow_general_knowledge=False, # set this to True will add instruction to encourage the LLM to incorporate general knowledge in the response, which may increase hallucinations, but could be useful in some use cases.\n",
|
||||
" json_mode=True, # set this to False if your LLM model does not support JSON mode.\n",
|
||||
" context_builder_params=context_builder_params,\n",
|
||||
" concurrent_coroutines=32,\n",
|
||||
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"### Overview of Cosmic Vocalization\n",
|
||||
"\n",
|
||||
"Cosmic Vocalization is a phenomenon that has captured the attention of various individuals and groups, becoming a focal point for community interest. It is perceived as a significant cosmic event, with interpretations ranging from a strategic security concern to a metaphorical interstellar duet [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"### Key Stakeholders and Perspectives\n",
|
||||
"\n",
|
||||
"1. **Paranormal Military Squad**: This group is actively engaged with Cosmic Vocalization, treating it as a strategic element in their security measures. Their involvement underscores the importance of Cosmic Vocalization in broader security contexts. They metaphorically view the Universe as a concert hall, suggesting a unique perspective on cosmic events and their implications for human entities [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"2. **Alex Mercer**: Alex Mercer perceives Cosmic Vocalization as part of an interstellar duet, indicating a responsive and perhaps artistic approach to understanding these cosmic phenomena. This perspective highlights the diverse interpretations and cultural significance attributed to Cosmic Vocalization [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"3. **Taylor Cruz**: Taylor Cruz expresses concerns about Cosmic Vocalization, fearing it might serve as a homing tune. This perspective introduces a layer of urgency and potential threat, suggesting that Cosmic Vocalization could have implications beyond mere observation, possibly affecting security or existential considerations [Data: Reports (6)].\n",
|
||||
"\n",
|
||||
"### Implications\n",
|
||||
"\n",
|
||||
"The involvement of these stakeholders and their varied perspectives on Cosmic Vocalization illustrate the complexity and multifaceted nature of this phenomenon. It is not only a subject of scientific and strategic interest but also a cultural and existential topic that prompts diverse interpretations and responses. The strategic engagement by the Paranormal Military Squad and the concerns raised by individuals like Taylor Cruz highlight the potential significance of Cosmic Vocalization in both security and broader cosmic contexts [Data: Reports (6)].\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = await search_engine.asearch(\n",
|
||||
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(result.response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>id</th>\n",
|
||||
" <th>title</th>\n",
|
||||
" <th>occurrence weight</th>\n",
|
||||
" <th>content</th>\n",
|
||||
" <th>rank</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>15</td>\n",
|
||||
" <td>Dulce Base and the Paranormal Military Squad: ...</td>\n",
|
||||
" <td>1.00</td>\n",
|
||||
" <td># Dulce Base and the Paranormal Military Squad...</td>\n",
|
||||
" <td>9.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>Earth's Interstellar Communication Initiative</td>\n",
|
||||
" <td>0.16</td>\n",
|
||||
" <td># Earth's Interstellar Communication Initiativ...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>16</td>\n",
|
||||
" <td>Dulce Military Base and Alien Intelligence Com...</td>\n",
|
||||
" <td>0.08</td>\n",
|
||||
" <td># Dulce Military Base and Alien Intelligence C...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>18</td>\n",
|
||||
" <td>Paranormal Military Squad Team and Dulce Base'...</td>\n",
|
||||
" <td>0.04</td>\n",
|
||||
" <td># Paranormal Military Squad Team and Dulce Bas...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>19</td>\n",
|
||||
" <td>Central Terminal and Viewing Monitors at Dulce...</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Central Terminal and Viewing Monitors at Dul...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>Dulce Facility and Control Room of Dulce: Extr...</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Dulce Facility and Control Room of Dulce: Ex...</td>\n",
|
||||
" <td>8.5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>6</td>\n",
|
||||
" <td>Cosmic Vocalization and Universe Interactions</td>\n",
|
||||
" <td>0.02</td>\n",
|
||||
" <td># Cosmic Vocalization and Universe Interaction...</td>\n",
|
||||
" <td>7.5</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" id title occurrence weight \\\n",
|
||||
"0 15 Dulce Base and the Paranormal Military Squad: ... 1.00 \n",
|
||||
"1 1 Earth's Interstellar Communication Initiative 0.16 \n",
|
||||
"2 16 Dulce Military Base and Alien Intelligence Com... 0.08 \n",
|
||||
"3 18 Paranormal Military Squad Team and Dulce Base'... 0.04 \n",
|
||||
"4 19 Central Terminal and Viewing Monitors at Dulce... 0.02 \n",
|
||||
"5 4 Dulce Facility and Control Room of Dulce: Extr... 0.02 \n",
|
||||
"6 6 Cosmic Vocalization and Universe Interactions 0.02 \n",
|
||||
"\n",
|
||||
" content rank \n",
|
||||
"0 # Dulce Base and the Paranormal Military Squad... 9.5 \n",
|
||||
"1 # Earth's Interstellar Communication Initiativ... 8.5 \n",
|
||||
"2 # Dulce Military Base and Alien Intelligence C... 8.5 \n",
|
||||
"3 # Paranormal Military Squad Team and Dulce Bas... 8.5 \n",
|
||||
"4 # Central Terminal and Viewing Monitors at Dul... 8.5 \n",
|
||||
"5 # Dulce Facility and Control Room of Dulce: Ex... 8.5 \n",
|
||||
"6 # Cosmic Vocalization and Universe Interaction... 7.5 "
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# inspect the data used to build the context for the LLM responses\n",
|
||||
"result.context_data[\"reports\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Build context (gpt-4o-mini)\n",
|
||||
"LLM calls: 12. Prompt tokens: 8565. Output tokens: 1091.\n",
|
||||
"Map-reduce (gpt-4o)\n",
|
||||
"LLM calls: 2. Prompt tokens: 5771. Output tokens: 600.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# inspect number of LLM calls and tokens in dynamic community selection\n",
|
||||
"llm_calls = result.llm_calls_categories[\"build_context\"]\n",
|
||||
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
|
||||
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
|
||||
"print(\n",
|
||||
" f\"Build context ({mini_llm.model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")\n",
|
||||
"# inspect number of LLM calls and tokens in map-reduce\n",
|
||||
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
|
||||
"prompt_tokens = (\n",
|
||||
" result.prompt_tokens_categories[\"map\"] + result.prompt_tokens_categories[\"reduce\"]\n",
|
||||
")\n",
|
||||
"output_tokens = (\n",
|
||||
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"Map-reduce ({llm.model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "graphrag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -32,6 +32,7 @@ from graphrag.query.factories import (
|
||||
get_local_search_engine,
|
||||
)
|
||||
from graphrag.query.indexer_adapters import (
|
||||
read_indexer_communities,
|
||||
read_indexer_covariates,
|
||||
read_indexer_entities,
|
||||
read_indexer_relationships,
|
||||
@ -52,8 +53,10 @@ async def global_search(
|
||||
config: GraphRagConfig,
|
||||
nodes: pd.DataFrame,
|
||||
entities: pd.DataFrame,
|
||||
communities: pd.DataFrame,
|
||||
community_reports: pd.DataFrame,
|
||||
community_level: int,
|
||||
community_level: int | None,
|
||||
dynamic_community_selection: bool,
|
||||
response_type: str,
|
||||
query: str,
|
||||
) -> tuple[
|
||||
@ -67,8 +70,10 @@ async def global_search(
|
||||
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
||||
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
|
||||
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
|
||||
- communities (pd.DataFrame): A DataFrame containing the final communities (from create_final_communities.parquet)
|
||||
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
|
||||
- community_level (int): The community level to search at.
|
||||
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level cap the maximum level to search.
|
||||
- response_type (str): The type of response to return.
|
||||
- query (str): The user query to search for.
|
||||
|
||||
@ -80,13 +85,21 @@ async def global_search(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
reports = read_indexer_reports(community_reports, nodes, community_level)
|
||||
_entities = read_indexer_entities(nodes, entities, community_level)
|
||||
_communities = read_indexer_communities(communities, nodes, community_reports)
|
||||
reports = read_indexer_reports(
|
||||
community_reports,
|
||||
nodes,
|
||||
community_level=community_level,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
)
|
||||
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
|
||||
search_engine = get_global_search_engine(
|
||||
config,
|
||||
reports=reports,
|
||||
entities=_entities,
|
||||
communities=_communities,
|
||||
response_type=response_type,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
)
|
||||
result: SearchResult = await search_engine.asearch(query=query)
|
||||
response = result.response
|
||||
@ -99,8 +112,10 @@ async def global_search_streaming(
|
||||
config: GraphRagConfig,
|
||||
nodes: pd.DataFrame,
|
||||
entities: pd.DataFrame,
|
||||
communities: pd.DataFrame,
|
||||
community_reports: pd.DataFrame,
|
||||
community_level: int,
|
||||
community_level: int | None,
|
||||
dynamic_community_selection: bool,
|
||||
response_type: str,
|
||||
query: str,
|
||||
) -> AsyncGenerator:
|
||||
@ -113,8 +128,10 @@ async def global_search_streaming(
|
||||
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
|
||||
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
|
||||
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
|
||||
- communities (pd.DataFrame): A DataFrame containing the final communities (from create_final_communities.parquet)
|
||||
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
|
||||
- community_level (int): The community level to search at.
|
||||
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level cap the maximum level to search.
|
||||
- response_type (str): The type of response to return.
|
||||
- query (str): The user query to search for.
|
||||
|
||||
@ -126,13 +143,21 @@ async def global_search_streaming(
|
||||
------
|
||||
TODO: Document any exceptions to expect.
|
||||
"""
|
||||
reports = read_indexer_reports(community_reports, nodes, community_level)
|
||||
_entities = read_indexer_entities(nodes, entities, community_level)
|
||||
_communities = read_indexer_communities(communities, nodes, community_reports)
|
||||
reports = read_indexer_reports(
|
||||
community_reports,
|
||||
nodes,
|
||||
community_level=community_level,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
)
|
||||
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
|
||||
search_engine = get_global_search_engine(
|
||||
config,
|
||||
reports=reports,
|
||||
entities=_entities,
|
||||
communities=_communities,
|
||||
response_type=response_type,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
)
|
||||
search_result = search_engine.astream_search(query=query)
|
||||
|
||||
|
||||
@ -327,6 +327,10 @@ def _query_cli(
|
||||
help="The community level in the Leiden community hierarchy from which to load community reports. Higher values represent reports from smaller communities."
|
||||
),
|
||||
] = 2,
|
||||
dynamic_community_selection: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Use global search with dynamic community selection."),
|
||||
] = False,
|
||||
response_type: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
@ -355,6 +359,7 @@ def _query_cli(
|
||||
data_dir=data,
|
||||
root_dir=root,
|
||||
community_level=community_level,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
response_type=response_type,
|
||||
streaming=streaming,
|
||||
query=query,
|
||||
|
||||
@ -22,7 +22,8 @@ def run_global_search(
|
||||
config_filepath: Path | None,
|
||||
data_dir: Path | None,
|
||||
root_dir: Path,
|
||||
community_level: int,
|
||||
community_level: int | None,
|
||||
dynamic_community_selection: bool,
|
||||
response_type: str,
|
||||
streaming: bool,
|
||||
query: str,
|
||||
@ -42,12 +43,14 @@ def run_global_search(
|
||||
parquet_list=[
|
||||
"create_final_nodes.parquet",
|
||||
"create_final_entities.parquet",
|
||||
"create_final_communities.parquet",
|
||||
"create_final_community_reports.parquet",
|
||||
],
|
||||
optional_list=[],
|
||||
)
|
||||
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
|
||||
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
|
||||
final_communities: pd.DataFrame = dataframe_dict["create_final_communities"]
|
||||
final_community_reports: pd.DataFrame = dataframe_dict[
|
||||
"create_final_community_reports"
|
||||
]
|
||||
@ -63,8 +66,10 @@ def run_global_search(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
communities=final_communities,
|
||||
community_reports=final_community_reports,
|
||||
community_level=community_level,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
):
|
||||
@ -85,8 +90,10 @@ def run_global_search(
|
||||
config=config,
|
||||
nodes=final_nodes,
|
||||
entities=final_entities,
|
||||
communities=final_communities,
|
||||
community_reports=final_community_reports,
|
||||
community_level=community_level,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
response_type=response_type,
|
||||
query=query,
|
||||
)
|
||||
|
||||
@ -119,8 +119,16 @@ GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000
|
||||
GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000
|
||||
GLOBAL_SEARCH_CONCURRENCY = 32
|
||||
|
||||
# DRIFT Search
|
||||
# Global Search with dynamic community selection
|
||||
DYNAMIC_SEARCH_LLM_MODEL = "gpt-4o-mini"
|
||||
DYNAMIC_SEARCH_RATE_THRESHOLD = 1
|
||||
DYNAMIC_SEARCH_KEEP_PARENT = False
|
||||
DYNAMIC_SEARCH_NUM_REPEATS = 1
|
||||
DYNAMIC_SEARCH_USE_SUMMARY = False
|
||||
DYNAMIC_SEARCH_CONCURRENT_COROUTINES = 16
|
||||
DYNAMIC_SEARCH_MAX_LEVEL = 2
|
||||
|
||||
# DRIFT Search
|
||||
DRIFT_SEARCH_LLM_TEMPERATURE = 0
|
||||
DRIFT_SEARCH_LLM_TOP_P = 1
|
||||
DRIFT_SEARCH_LLM_N = 3
|
||||
|
||||
@ -43,3 +43,33 @@ class GlobalSearchConfig(BaseModel):
|
||||
description="The number of concurrent requests.",
|
||||
default=defs.GLOBAL_SEARCH_CONCURRENCY,
|
||||
)
|
||||
|
||||
# configurations for dynamic community selection
|
||||
dynamic_search_llm: str = Field(
|
||||
description="LLM model to use for dynamic community selection",
|
||||
default=defs.DYNAMIC_SEARCH_LLM_MODEL,
|
||||
)
|
||||
dynamic_search_threshold: int = Field(
|
||||
description="Rating threshold in include a community report",
|
||||
default=defs.DYNAMIC_SEARCH_RATE_THRESHOLD,
|
||||
)
|
||||
dynamic_search_keep_parent: bool = Field(
|
||||
description="Keep parent community if any of the child communities are relevant",
|
||||
default=defs.DYNAMIC_SEARCH_KEEP_PARENT,
|
||||
)
|
||||
dynamic_search_num_repeats: int = Field(
|
||||
description="Number of times to rate the same community report",
|
||||
default=defs.DYNAMIC_SEARCH_NUM_REPEATS,
|
||||
)
|
||||
dynamic_search_use_summary: bool = Field(
|
||||
description="Use community summary instead of full_context",
|
||||
default=defs.DYNAMIC_SEARCH_USE_SUMMARY,
|
||||
)
|
||||
dynamic_search_concurrent_coroutines: int = Field(
|
||||
description="Number of concurrent coroutines to rate community reports",
|
||||
default=defs.DYNAMIC_SEARCH_CONCURRENT_COROUTINES,
|
||||
)
|
||||
dynamic_search_max_level: int = Field(
|
||||
description="The maximum level of community hierarchy to consider if none of the processed communities are relevant",
|
||||
default=defs.DYNAMIC_SEARCH_MAX_LEVEL,
|
||||
)
|
||||
|
||||
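A hedged sketch of setting these new fields programmatically. The field names and defaults come from the diff itself; the module path and the `defs` alias are assumptions based on the surrounding code.

import graphrag.config.defaults as defs  # assumed alias, matching `defs` above
from graphrag.config.models.global_search_config import GlobalSearchConfig  # assumed path

gs_config = GlobalSearchConfig(
    dynamic_search_llm=defs.DYNAMIC_SEARCH_LLM_MODEL,  # "gpt-4o-mini"
    dynamic_search_threshold=2,        # only keep reports rated >= 2
    dynamic_search_keep_parent=True,   # retain parents of relevant children
    dynamic_search_num_repeats=3,      # majority-vote over three ratings
    dynamic_search_use_summary=False,  # rate on full_content, not the summary
    dynamic_search_concurrent_coroutines=16,
    dynamic_search_max_level=2,
)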
@ -29,7 +29,7 @@ CLAIM_DESCRIPTION = "description"
|
||||
CLAIM_DETAILS = "claim_details"
|
||||
|
||||
# COMMUNITY HIERARCHY TABLE SCHEMA
|
||||
SUB_COMMUNITY = "sub_communitty"
|
||||
SUB_COMMUNITY = "sub_community"
|
||||
SUB_COMMUNITY_SIZE = "sub_community_size"
|
||||
COMMUNITY_LEVEL = "level"
|
||||
|
||||
|
||||
@ -25,6 +25,9 @@ class Community(Named):
|
||||
covariate_ids: dict[str, list[str]] | None = None
|
||||
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
|
||||
|
||||
sub_community_ids: list[str] | None = None
|
||||
"""List of community IDs of the child nodes of this community (optional)."""
|
||||
|
||||
attributes: dict[str, Any] | None = None
|
||||
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
|
||||
|
||||
@ -45,6 +48,7 @@ class Community(Named):
|
||||
entities_key: str = "entity_ids",
|
||||
relationships_key: str = "relationship_ids",
|
||||
covariates_key: str = "covariate_ids",
|
||||
sub_communities_key: str = "sub_community_ids",
|
||||
attributes_key: str = "attributes",
|
||||
size_key: str = "size",
|
||||
period_key: str = "period",
|
||||
@ -58,6 +62,7 @@ class Community(Named):
|
||||
entity_ids=d.get(entities_key),
|
||||
relationship_ids=d.get(relationships_key),
|
||||
covariate_ids=d.get(covariates_key),
|
||||
sub_community_ids=d.get(sub_communities_key),
|
||||
attributes=d.get(attributes_key),
|
||||
size=d.get(size_key),
|
||||
period=d.get(period_key),
|
||||
|
||||
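To make the new hierarchy field concrete, a small illustrative call to from_dict. Only sub_community_ids is introduced by this diff; the remaining keys are assumed from the existing Community/Named fields.

from graphrag.model import Community

row = {
    "id": "0",
    "short_id": "0",
    "title": "Community 0",
    "level": "0",
    "sub_community_ids": ["2", "3"],  # children restored from the hierarchy
}
community = Community.from_dict(row)
assert community.sub_community_ids == ["2", "3"]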
@ -4,6 +4,7 @@
|
||||
"""Base classes for global and local context builders."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@ -12,13 +13,27 @@ from graphrag.query.context_builder.conversation_history import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextBuilderResult:
|
||||
"""A class to hold the results of the build_context."""
|
||||
|
||||
context_chunks: str | list[str]
|
||||
context_records: dict[str, pd.DataFrame]
|
||||
llm_calls: int = 0
|
||||
prompt_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
|
||||
class GlobalContextBuilder(ABC):
|
||||
"""Base class for global-search context builders."""
|
||||
|
||||
@abstractmethod
|
||||
def build_context(
|
||||
self, conversation_history: ConversationHistory | None = None, **kwargs
|
||||
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
|
||||
async def build_context(
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
**kwargs,
|
||||
) -> ContextBuilderResult:
|
||||
"""Build the context for the global search mode."""
|
||||
|
||||
|
||||
@ -31,7 +46,7 @@ class LocalContextBuilder(ABC):
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
|
||||
) -> ContextBuilderResult:
|
||||
"""Build the context for the local search mode."""
|
||||
|
||||
|
||||
@ -43,5 +58,5 @@ class DRIFTContextBuilder(ABC):
|
||||
self,
|
||||
query: str,
|
||||
**kwargs,
|
||||
) -> pd.DataFrame:
|
||||
) -> tuple[pd.DataFrame, dict[str, int]]:
|
||||
"""Build the context for the primer search actions."""
|
||||
|
||||
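Under the new async contract, a minimal (toy) global context builder might look like the following; the class body and data are illustrative only, not part of the change.

import pandas as pd

from graphrag.query.context_builder.builders import (
    ContextBuilderResult,
    GlobalContextBuilder,
)
from graphrag.query.context_builder.conversation_history import (
    ConversationHistory,
)

class StaticGlobalContext(GlobalContextBuilder):
    """Toy builder returning a single fixed context chunk."""

    async def build_context(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        **kwargs,
    ) -> ContextBuilderResult:
        records = {"reports": pd.DataFrame({"id": ["0"], "content": ["..."]})}
        return ContextBuilderResult(
            context_chunks=["id|content\n0|..."],
            context_records=records,
            # no LLM was needed to assemble this static context
            llm_calls=0,
        )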
183
graphrag/query/context_builder/dynamic_community_selection.py
Normal file
@ -0,0 +1,183 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Algorithm to dynamically select relevant communities with respect to a query."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from time import time
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.model import Community, CommunityReport
|
||||
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
|
||||
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RATE_LLM_PARAMS = {"temperature": 0.0, "max_tokens": 2000}
|
||||
|
||||
|
||||
class DynamicCommunitySelection:
|
||||
"""Dynamic community selection to select community reports that are relevant to the query.
|
||||
|
||||
Any community report with a rating EQUAL or ABOVE the rating_threshold is considered relevant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
community_reports: list[CommunityReport],
|
||||
communities: list[Community],
|
||||
llm: BaseLLM,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
rate_query: str = RATE_QUERY,
|
||||
use_summary: bool = False,
|
||||
threshold: int = 1,
|
||||
keep_parent: bool = False,
|
||||
num_repeats: int = 1,
|
||||
max_level: int = 2,
|
||||
concurrent_coroutines: int = 8,
|
||||
llm_kwargs: Any = DEFAULT_RATE_LLM_PARAMS,
|
||||
):
|
||||
self.llm = llm
|
||||
self.token_encoder = token_encoder
|
||||
self.rate_query = rate_query
|
||||
self.num_repeats = num_repeats
|
||||
self.use_summary = use_summary
|
||||
self.threshold = threshold
|
||||
self.keep_parent = keep_parent
|
||||
self.max_level = max_level
|
||||
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
|
||||
self.llm_kwargs = llm_kwargs
|
||||
|
||||
self.reports = {report.community_id: report for report in community_reports}
|
||||
# mapping from community to sub communities
|
||||
self.node2children = {
|
||||
community.id: (
|
||||
[]
|
||||
if community.sub_community_ids is None
|
||||
else community.sub_community_ids
|
||||
)
|
||||
for community in communities
|
||||
}
|
||||
# mapping from community to parent community
|
||||
self.node2parent: dict[str, str] = {
|
||||
sub_community: community
|
||||
for community, sub_communities in self.node2children.items()
|
||||
for sub_community in sub_communities
|
||||
}
|
||||
# mapping from level to communities
|
||||
self.levels: dict[str, list[str]] = {}
|
||||
for community in communities:
|
||||
if community.level not in self.levels:
|
||||
self.levels[community.level] = []
|
||||
if community.id in self.reports:
|
||||
self.levels[community.level].append(community.id)
|
||||
|
||||
# start from root communities (level 0)
|
||||
self.starting_communities = self.levels["0"]
|
||||
|
||||
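To illustrate the three mappings built in __init__, consider a toy hierarchy where community 0 (level 0) has children 2 and 3 (level 1), and community 1 has none; the IDs here are invented for the example.

node2children = {"0": ["2", "3"], "1": [], "2": [], "3": []}
node2parent = {"2": "0", "3": "0"}           # inverse of node2children
levels = {"0": ["0", "1"], "1": ["2", "3"]}  # level -> community ids
starting_communities = levels["0"]           # rating begins at the roots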
async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any]]:
|
||||
"""
|
||||
Select relevant communities with respect to the query.
|
||||
|
||||
Args:
|
||||
query: the query to rate against
|
||||
"""
|
||||
start = time()
|
||||
queue = deepcopy(self.starting_communities)
|
||||
level = 0
|
||||
|
||||
ratings = {} # store the ratings for each community
|
||||
llm_info: dict[str, Any] = {
|
||||
"llm_calls": 0,
|
||||
"prompt_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
relevant_communities = set()
|
||||
while queue:
|
||||
gather_results = await asyncio.gather(*[
|
||||
rate_relevancy(
|
||||
query=query,
|
||||
description=(
|
||||
self.reports[community].summary
|
||||
if self.use_summary
|
||||
else self.reports[community].full_content
|
||||
),
|
||||
llm=self.llm,
|
||||
token_encoder=self.token_encoder,
|
||||
rate_query=self.rate_query,
|
||||
num_repeats=self.num_repeats,
|
||||
semaphore=self.semaphore,
|
||||
**self.llm_kwargs,
|
||||
)
|
||||
for community in queue
|
||||
])
|
||||
|
||||
communities_to_rate = []
|
||||
for community, result in zip(queue, gather_results, strict=True):
|
||||
rating = result["rating"]
|
||||
log.debug(
|
||||
"dynamic community selection: community %s rating %s",
|
||||
community,
|
||||
rating,
|
||||
)
|
||||
ratings[community] = rating
|
||||
llm_info["llm_calls"] += result["llm_calls"]
|
||||
llm_info["prompt_tokens"] += result["prompt_tokens"]
|
||||
llm_info["output_tokens"] += result["output_tokens"]
|
||||
if rating >= self.threshold:
|
||||
relevant_communities.add(community)
|
||||
# find children nodes of the current node and append them to the queue
|
||||
# TODO check why some sub_communities are NOT in report_df
|
||||
if community in self.node2children:
|
||||
for sub_community in self.node2children[community]:
|
||||
if sub_community in self.reports:
|
||||
communities_to_rate.append(sub_community)
|
||||
else:
|
||||
log.debug(
|
||||
"dynamic community selection: cannot find community %s in reports",
|
||||
sub_community,
|
||||
)
|
||||
# remove parent node if the current node is deemed relevant
|
||||
if not self.keep_parent and community in self.node2parent:
|
||||
relevant_communities.discard(self.node2parent[community])
|
||||
queue = communities_to_rate
|
||||
level += 1
|
||||
if (
|
||||
(len(queue) == 0)
|
||||
and (len(relevant_communities) == 0)
|
||||
and (str(level) in self.levels)
|
||||
and (level <= self.max_level)
|
||||
):
|
||||
log.info(
|
||||
"dynamic community selection: no relevant community "
|
||||
"reports, adding all reports at level %s to rate.",
|
||||
level,
|
||||
)
|
||||
# append all communities at the next level to queue
|
||||
queue = self.levels[str(level)]
|
||||
|
||||
community_reports = [
|
||||
self.reports[community] for community in relevant_communities
|
||||
]
|
||||
end = time()
|
||||
|
||||
log.info(
|
||||
"dynamic community selection (took: %ss)\n"
|
||||
"\trating distribution %s\n"
|
||||
"\t%s out of %s community reports are relevant\n"
|
||||
"\tprompt tokens: %s, output tokens: %s",
|
||||
int(end - start),
|
||||
dict(sorted(Counter(ratings.values()).items())),
|
||||
len(relevant_communities),
|
||||
len(self.reports),
|
||||
llm_info["prompt_tokens"],
|
||||
llm_info["output_tokens"],
|
||||
)
|
||||
llm_info["ratings"] = ratings
|
||||
return community_reports, llm_info
|
||||
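A hedged end-to-end sketch of driving the selector directly. The constructor arguments match the class above; llm, reports, and communities are assumed to come from get_llm and the read_indexer_* adapters shown later in this change.

import asyncio

import tiktoken

from graphrag.query.context_builder.dynamic_community_selection import (
    DynamicCommunitySelection,
)

def run_selection(llm, reports, communities):
    selector = DynamicCommunitySelection(
        community_reports=reports,
        communities=communities,
        llm=llm,
        token_encoder=tiktoken.encoding_for_model("gpt-4o-mini"),
        threshold=1,  # keep anything rated >= 1
        max_level=2,  # fall back to deeper levels if nothing matches
    )
    relevant, llm_info = asyncio.run(
        selector.select("What are the main themes in this dataset?")
    )
    print(len(relevant), "reports kept;", llm_info["llm_calls"], "LLM calls")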
23
graphrag/query/context_builder/rate_prompt.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Global search with dynamic community selection prompt."""
|
||||
|
||||
RATE_QUERY = """
|
||||
---Role---
|
||||
You are a helpful assistant responsible for deciding whether the provided information is useful in answering a given question, even if it is only partially relevant.
|
||||
---Goal---
|
||||
On a scale from 0 to 5, please rate how relevant or helpful the provided information is in answering the question.
---Information---
|
||||
{description}
|
||||
---Question---
|
||||
{question}
|
||||
---Target response length and format---
|
||||
Please respond in the following JSON format with two entries:
- "reason": the reasoning of your rating, please include information that you have considered.
|
||||
- "rating": the relevancy rating from 0 to 5, where 0 is the least relevant and 5 is the most relevant.
|
||||
{{
|
||||
"reason": str,
|
||||
"rating": int.
|
||||
}}
|
||||
"""
|
||||
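The doubled braces in the JSON template survive str.format, so only {description} and {question} are substituted; a quick standalone check:

from graphrag.query.context_builder.rate_prompt import RATE_QUERY

filled = RATE_QUERY.format(
    description="Report: this community centers on harbor logistics...",
    question="What are the main themes in this dataset?",
)
# The {{ ... }} block is emitted as literal { ... } for the model to fill in.
print(filled)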
76
graphrag/query/context_builder/rate_relevancy.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Algorithm to rate the relevancy between a query and description text."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
|
||||
from graphrag.llm.openai.utils import try_parse_json_object
|
||||
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
|
||||
from graphrag.query.llm.base import BaseLLM
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def rate_relevancy(
|
||||
query: str,
|
||||
description: str,
|
||||
llm: BaseLLM,
|
||||
token_encoder: tiktoken.Encoding,
|
||||
rate_query: str = RATE_QUERY,
|
||||
num_repeats: int = 1,
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Rate the relevancy between the query and description on a scale of 0 to 5.
|
||||
Args:
|
||||
query: the query (or question) to rate against
|
||||
description: the community description to rate; it can be the community
title, summary, or the full content.
llm: LLM model to use for rating
|
||||
token_encoder: token encoder
|
||||
num_repeats: number of times to repeat the rating process for the same community (default: 1)
|
||||
llm_kwargs: additional arguments to pass to the LLM model
|
||||
semaphore: asyncio.Semaphore to limit the number of concurrent LLM calls (default: None)
|
||||
"""
|
||||
llm_calls, prompt_tokens, output_tokens, ratings = 0, 0, 0, []
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": rate_query.format(description=description, question=query),
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
]
|
||||
for _ in range(num_repeats):
|
||||
async with semaphore if semaphore is not None else nullcontext():
|
||||
response = await llm.agenerate(messages=messages, **llm_kwargs)
|
||||
try:
|
||||
_, parsed_response = try_parse_json_object(response)
|
||||
ratings.append(parsed_response["rating"])
|
||||
except KeyError:
|
||||
# in case of json parsing error, default to rating 1 so the report is kept.
|
||||
# json parsing error should rarely happen.
|
||||
log.info("Error parsing json response, defaulting to rating 1")
|
||||
ratings.append(1)
|
||||
llm_calls += 1
|
||||
prompt_tokens += num_tokens(messages[0]["content"], token_encoder)
|
||||
output_tokens += num_tokens(response, token_encoder)
|
||||
# select the decision with the most votes
|
||||
options, counts = np.unique(ratings, return_counts=True)
|
||||
rating = int(options[np.argmax(counts)])
|
||||
return {
|
||||
"rating": rating,
|
||||
"ratings": ratings,
|
||||
"llm_calls": llm_calls,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
}
|
||||
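The final step of rate_relevancy is a majority vote over the repeated ratings; a tiny standalone illustration (note that on ties np.argmax keeps the first, i.e. smallest, rating):

import numpy as np

ratings = [4, 5, 4]  # e.g. num_repeats=3
options, counts = np.unique(ratings, return_counts=True)
rating = int(options[np.argmax(counts)])
print(rating)  # 4, the modal rating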
@ -3,14 +3,13 @@
|
||||
|
||||
"""Query Factory methods to support CLI."""
|
||||
|
||||
import tiktoken
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
from copy import deepcopy
|
||||
|
||||
from graphrag.config import (
|
||||
GraphRagConfig,
|
||||
LLMType,
|
||||
)
|
||||
import tiktoken
|
||||
|
||||
from graphrag.config import GraphRagConfig
|
||||
from graphrag.model import (
|
||||
Community,
|
||||
CommunityReport,
|
||||
Covariate,
|
||||
Entity,
|
||||
@ -18,9 +17,7 @@ from graphrag.model import (
|
||||
TextUnit,
|
||||
)
|
||||
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
from graphrag.query.llm.oai.typing import OpenaiApiType
|
||||
from graphrag.query.llm.get_client import get_llm, get_text_embedder
|
||||
from graphrag.query.structured_search.drift_search.drift_context import (
|
||||
DRIFTSearchContextBuilder,
|
||||
)
|
||||
@ -36,71 +33,6 @@ from graphrag.query.structured_search.local_search.search import LocalSearch
|
||||
from graphrag.vector_stores import BaseVectorStore
|
||||
|
||||
|
||||
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
|
||||
"""Get the LLM client."""
|
||||
is_azure_client = (
|
||||
config.llm.type == LLMType.AzureOpenAIChat
|
||||
or config.llm.type == LLMType.AzureOpenAI
|
||||
)
|
||||
debug_llm_key = config.llm.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.llm.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_llm_key)}",
|
||||
}
|
||||
audience = (
|
||||
config.llm.audience
|
||||
if config.llm.audience
|
||||
else "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
print(f"creating llm client with {llm_debug_info}") # noqa T201
|
||||
return ChatOpenAI(
|
||||
api_key=config.llm.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.llm.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
model=config.llm.model,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
deployment_name=config.llm.deployment_name,
|
||||
api_version=config.llm.api_version,
|
||||
max_retries=config.llm.max_retries,
|
||||
request_timeout=config.llm.request_timeout,
|
||||
)
|
||||
|
||||
|
||||
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
|
||||
"""Get the LLM client for embeddings."""
|
||||
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
|
||||
debug_embedding_api_key = config.embeddings.llm.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.embeddings.llm.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
|
||||
}
|
||||
if config.embeddings.llm.audience is None:
|
||||
audience = "https://cognitiveservices.azure.com/.default"
|
||||
else:
|
||||
audience = config.embeddings.llm.audience
|
||||
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
|
||||
return OpenAIEmbedding(
|
||||
api_key=config.embeddings.llm.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.embeddings.llm.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.embeddings.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
model=config.embeddings.llm.model,
|
||||
deployment_name=config.embeddings.llm.deployment_name,
|
||||
api_version=config.embeddings.llm.api_version,
|
||||
max_retries=config.embeddings.llm.max_retries,
|
||||
)
|
||||
|
||||
|
||||
def get_local_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
@ -160,16 +92,39 @@ def get_global_search_engine(
|
||||
config: GraphRagConfig,
|
||||
reports: list[CommunityReport],
|
||||
entities: list[Entity],
|
||||
communities: list[Community],
|
||||
response_type: str,
|
||||
dynamic_community_selection: bool = False,
|
||||
) -> GlobalSearch:
|
||||
"""Create a global search engine based on data + configuration."""
|
||||
token_encoder = tiktoken.get_encoding(config.encoding_model)
|
||||
gs_config = config.global_search
|
||||
|
||||
dynamic_community_selection_kwargs = {}
|
||||
if dynamic_community_selection:
|
||||
gs_config = config.global_search
|
||||
_config = deepcopy(config)
|
||||
_config.llm.model = _config.llm.deployment_name = gs_config.dynamic_search_llm
|
||||
dynamic_community_selection_kwargs.update({
|
||||
"llm": get_llm(_config),
|
||||
"token_encoder": tiktoken.encoding_for_model(gs_config.dynamic_search_llm),
|
||||
"keep_parent": gs_config.dynamic_search_keep_parent,
|
||||
"num_repeats": gs_config.dynamic_search_num_repeats,
|
||||
"use_summary": gs_config.dynamic_search_use_summary,
|
||||
"concurrent_coroutines": gs_config.dynamic_search_concurrent_coroutines,
|
||||
"threshold": gs_config.dynamic_search_threshold,
|
||||
"max_level": gs_config.dynamic_search_max_level,
|
||||
})
|
||||
|
||||
return GlobalSearch(
|
||||
llm=get_llm(config),
|
||||
context_builder=GlobalCommunityContext(
|
||||
community_reports=reports, entities=entities, token_encoder=token_encoder
|
||||
community_reports=reports,
|
||||
communities=communities,
|
||||
entities=entities,
|
||||
token_encoder=token_encoder,
|
||||
dynamic_community_selection=dynamic_community_selection,
|
||||
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
|
||||
),
|
||||
token_encoder=token_encoder,
|
||||
max_data_tokens=gs_config.data_max_tokens,
|
||||
|
||||
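A detail worth noting in the factory above: the config is deep-copied so the rating model can differ from the answer model without mutating shared settings. Reduced to its essence (the helper name is illustrative):

from copy import deepcopy

from graphrag.query.llm.get_client import get_llm

def make_rating_llm(config, rating_model="gpt-4o-mini"):
    # Redirect only the copy to the (typically cheaper) rating model;
    # the original config still drives the map/reduce answer LLM.
    _config = deepcopy(config)
    _config.llm.model = _config.llm.deployment_name = rating_model
    return get_llm(_config)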
@ -3,17 +3,27 @@
|
||||
"""Indexing-Engine to Query Read Adapters.
|
||||
|
||||
The parts of these functions that do type adaptation, renaming, collating, etc. should eventually go away.
|
||||
Ideally this is just a straight read-thorugh into the object model.
|
||||
Ideally this is just a straight read-through into the object model.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.model import CommunityReport, Covariate, Entity, Relationship, TextUnit
|
||||
from graphrag.index.operations.summarize_communities import restore_community_hierarchy
|
||||
from graphrag.model import (
|
||||
Community,
|
||||
CommunityReport,
|
||||
Covariate,
|
||||
Entity,
|
||||
Relationship,
|
||||
TextUnit,
|
||||
)
|
||||
from graphrag.query.factories import get_text_embedder
|
||||
from graphrag.query.input.loaders.dfs import (
|
||||
read_communities,
|
||||
read_community_reports,
|
||||
read_covariates,
|
||||
read_entities,
|
||||
@ -23,6 +33,8 @@ from graphrag.query.input.loaders.dfs import (
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
from graphrag.vector_stores.base import BaseVectorStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
|
||||
"""Read in the Text Units from the raw indexing outputs."""
|
||||
@ -66,23 +78,32 @@ def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relati
|
||||
def read_indexer_reports(
|
||||
final_community_reports: pd.DataFrame,
|
||||
final_nodes: pd.DataFrame,
|
||||
community_level: int,
|
||||
community_level: int | None,
|
||||
dynamic_community_selection: bool = False,
|
||||
content_embedding_col: str = "full_content_embedding",
|
||||
config: GraphRagConfig | None = None,
|
||||
) -> list[CommunityReport]:
|
||||
"""Read in the Community Reports from the raw indexing outputs."""
|
||||
"""Read in the Community Reports from the raw indexing outputs.
|
||||
|
||||
If dynamic_community_selection is disabled, select only the reports from the deepest (max-level) community that each entity belongs to.
"""
|
||||
report_df = final_community_reports
|
||||
entity_df = final_nodes
|
||||
entity_df = _filter_under_community_level(entity_df, community_level)
|
||||
entity_df.loc[:, "community"] = entity_df["community"].fillna(-1)
|
||||
entity_df.loc[:, "community"] = entity_df["community"].astype(int)
|
||||
if community_level is not None:
|
||||
entity_df = _filter_under_community_level(entity_df, community_level)
|
||||
report_df = _filter_under_community_level(report_df, community_level)
|
||||
|
||||
entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
|
||||
entity_df["community"] = entity_df["community"].astype(str)
|
||||
filtered_community_df = entity_df["community"].drop_duplicates()
|
||||
if not dynamic_community_selection:
|
||||
# perform community level roll up
|
||||
entity_df.loc[:, "community"] = entity_df["community"].fillna(-1)
|
||||
entity_df.loc[:, "community"] = entity_df["community"].astype(int)
|
||||
|
||||
entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
|
||||
entity_df["community"] = entity_df["community"].astype(str)
|
||||
filtered_community_df = entity_df["community"].drop_duplicates()
|
||||
|
||||
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
|
||||
|
||||
report_df = _filter_under_community_level(report_df, community_level)
|
||||
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
|
||||
if config and (
|
||||
content_embedding_col not in report_df.columns
|
||||
or report_df.loc[:, content_embedding_col].isna().any()
|
||||
@ -113,13 +134,15 @@ def read_indexer_report_embeddings(
|
||||
def read_indexer_entities(
|
||||
final_nodes: pd.DataFrame,
|
||||
final_entities: pd.DataFrame,
|
||||
community_level: int,
|
||||
community_level: int | None,
|
||||
) -> list[Entity]:
|
||||
"""Read in the Entities from the raw indexing outputs."""
|
||||
entity_df = final_nodes
|
||||
entity_embedding_df = final_entities
|
||||
|
||||
entity_df = _filter_under_community_level(entity_df, community_level)
|
||||
if community_level is not None:
|
||||
entity_df = _filter_under_community_level(entity_df, community_level)
|
||||
|
||||
entity_df = cast(pd.DataFrame, entity_df[["title", "degree", "community"]]).rename(
|
||||
columns={"title": "name", "degree": "rank"}
|
||||
)
|
||||
@ -128,11 +151,11 @@ def read_indexer_entities(
|
||||
entity_df["community"] = entity_df["community"].astype(int)
|
||||
entity_df["rank"] = entity_df["rank"].astype(int)
|
||||
|
||||
# for duplicate entities, keep the one with the highest community level
|
||||
# group entities by name and rank and remove duplicated community IDs
|
||||
entity_df = (
|
||||
entity_df.groupby(["name", "rank"]).agg({"community": "max"}).reset_index()
|
||||
entity_df.groupby(["name", "rank"]).agg({"community": set}).reset_index()
|
||||
)
|
||||
entity_df["community"] = entity_df["community"].apply(lambda x: [str(x)])
|
||||
entity_df["community"] = entity_df["community"].apply(lambda x: [str(i) for i in x])
|
||||
entity_df = entity_df.merge(
|
||||
entity_embedding_df, on="name", how="inner"
|
||||
).drop_duplicates(subset=["name"])
|
||||
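The switch from max to set in the aggregation means an entity now keeps every community it belongs to instead of only the deepest one; a toy run of the new behavior:

import pandas as pd

df = pd.DataFrame({
    "name": ["A", "A", "B"],
    "rank": [3, 3, 1],
    "community": [1, 4, 2],
})
out = df.groupby(["name", "rank"]).agg({"community": set}).reset_index()
out["community"] = out["community"].apply(lambda x: [str(i) for i in x])
print(out)
# entity A -> ['1', '4'] (set order not guaranteed); previously only ['4']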
@ -155,6 +178,60 @@ def read_indexer_entities(
|
||||
)
|
||||
|
||||
|
||||
def read_indexer_communities(
|
||||
final_communities: pd.DataFrame,
|
||||
final_nodes: pd.DataFrame,
|
||||
final_community_reports: pd.DataFrame,
|
||||
) -> list[Community]:
|
||||
"""Read in the Communities from the raw indexing outputs.
|
||||
|
||||
Reconstruct the community hierarchy information and add to the sub-community field.
|
||||
"""
|
||||
community_df = final_communities
|
||||
node_df = final_nodes
|
||||
report_df = final_community_reports
|
||||
|
||||
# ensure communities match community reports
missing_reports = community_df[
|
||||
~community_df.id.isin(report_df.community.unique())
|
||||
].id.to_list()
|
||||
if len(missing_reports):
|
||||
log.warning("Missing reports for communities: %s", missing_reports)
|
||||
community_df = community_df.loc[
|
||||
community_df.id.isin(report_df.community.unique())
|
||||
]
|
||||
node_df = node_df.loc[node_df.community.isin(report_df.community.unique())]
|
||||
|
||||
# reconstruct the community hierarchy
|
||||
# note that restore_community_hierarchy only returns communities that have sub-communities
community_hierarchy = restore_community_hierarchy(input=node_df)
|
||||
community_hierarchy = (
|
||||
community_hierarchy.groupby(["community"])
|
||||
.agg({"sub_community": list})
|
||||
.reset_index()
|
||||
.rename(columns={"community": "id", "sub_community": "sub_community_ids"})
|
||||
)
|
||||
# add sub community IDs to community DataFrame
|
||||
community_df = community_df.merge(community_hierarchy, on="id", how="left")
|
||||
# replace NaN sub community IDs with empty list
|
||||
community_df.sub_community_ids = community_df.sub_community_ids.apply(
|
||||
lambda x: x if isinstance(x, list) else []
|
||||
)
|
||||
|
||||
return read_communities(
|
||||
community_df,
|
||||
id_col="id",
|
||||
short_id_col="id",
|
||||
title_col="title",
|
||||
level_col="level",
|
||||
entities_col=None,
|
||||
relationships_col=None,
|
||||
covariates_col=None,
|
||||
sub_communities_col="sub_community_ids",
|
||||
attributes_cols=None,
|
||||
)
|
||||
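Putting the three adapters together for dynamic selection, a hedged sketch; the function names match the diff, while the parquet paths are assumptions.

import pandas as pd

from graphrag.query.indexer_adapters import (
    read_indexer_communities,
    read_indexer_entities,
    read_indexer_reports,
)

nodes = pd.read_parquet("output/create_final_nodes.parquet")
entities_df = pd.read_parquet("output/create_final_entities.parquet")
communities_df = pd.read_parquet("output/create_final_communities.parquet")
reports_df = pd.read_parquet("output/create_final_community_reports.parquet")

# community_level=None keeps all levels; dynamic selection prunes at query time.
reports = read_indexer_reports(
    reports_df, nodes, community_level=None, dynamic_community_selection=True
)
entities = read_indexer_entities(nodes, entities_df, community_level=None)
communities = read_indexer_communities(communities_df, nodes, reports_df)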
|
||||
|
||||
def embed_community_reports(
|
||||
reports_df: pd.DataFrame,
|
||||
embedder: OpenAIEmbedding,
|
||||
|
||||
@ -215,6 +215,7 @@ def read_communities(
|
||||
entities_col: str | None = "entity_ids",
|
||||
relationships_col: str | None = "relationship_ids",
|
||||
covariates_col: str | None = "covariate_ids",
|
||||
sub_communities_col: str | None = "sub_community_ids",
|
||||
attributes_cols: list[str] | None = None,
|
||||
) -> list[Community]:
|
||||
"""Read communities from a dataframe."""
|
||||
@ -230,6 +231,7 @@ def read_communities(
|
||||
covariate_ids=to_optional_dict(
|
||||
row, covariates_col, key_type=str, value_type=str
|
||||
),
|
||||
sub_community_ids=to_optional_list(row, sub_communities_col, item_type=str),
|
||||
attributes=(
|
||||
{col: row.get(col) for col in attributes_cols}
|
||||
if attributes_cols
|
||||
|
||||
76
graphrag/query/llm/get_client.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Initialize LLM and Embedding clients."""
|
||||
|
||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||
|
||||
from graphrag.config import GraphRagConfig, LLMType
|
||||
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
|
||||
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
|
||||
from graphrag.query.llm.oai.typing import OpenaiApiType
|
||||
|
||||
|
||||
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
|
||||
"""Get the LLM client."""
|
||||
is_azure_client = (
|
||||
config.llm.type == LLMType.AzureOpenAIChat
|
||||
or config.llm.type == LLMType.AzureOpenAI
|
||||
)
|
||||
debug_llm_key = config.llm.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.llm.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_llm_key)}",
|
||||
}
|
||||
audience = (
|
||||
config.llm.audience
|
||||
if config.llm.audience
|
||||
else "https://cognitiveservices.azure.com/.default"
|
||||
)
|
||||
print(f"creating llm client with {llm_debug_info}") # noqa T201
|
||||
return ChatOpenAI(
|
||||
api_key=config.llm.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.llm.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
model=config.llm.model,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
deployment_name=config.llm.deployment_name,
|
||||
api_version=config.llm.api_version,
|
||||
max_retries=config.llm.max_retries,
|
||||
request_timeout=config.llm.request_timeout,
|
||||
)
|
||||
|
||||
|
||||
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
|
||||
"""Get the LLM client for embeddings."""
|
||||
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
|
||||
debug_embedding_api_key = config.embeddings.llm.api_key or ""
|
||||
llm_debug_info = {
|
||||
**config.embeddings.llm.model_dump(),
|
||||
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
|
||||
}
|
||||
if config.embeddings.llm.audience is None:
|
||||
audience = "https://cognitiveservices.azure.com/.default"
|
||||
else:
|
||||
audience = config.embeddings.llm.audience
|
||||
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
|
||||
return OpenAIEmbedding(
|
||||
api_key=config.embeddings.llm.api_key,
|
||||
azure_ad_token_provider=(
|
||||
get_bearer_token_provider(DefaultAzureCredential(), audience)
|
||||
if is_azure_client and not config.embeddings.llm.api_key
|
||||
else None
|
||||
),
|
||||
api_base=config.embeddings.llm.api_base,
|
||||
organization=config.llm.organization,
|
||||
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
|
||||
model=config.embeddings.llm.model,
|
||||
deployment_name=config.embeddings.llm.deployment_name,
|
||||
api_version=config.embeddings.llm.api_version,
|
||||
max_retries=config.embeddings.llm.max_retries,
|
||||
)
|
||||
@ -6,12 +6,16 @@
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
|
||||
OPENAI_RETRY_ERROR_TYPES = (
|
||||
# TODO: update these when we update to OpenAI 1+ library
|
||||
cast(Any, openai).RateLimitError,
|
||||
cast(Any, openai).APIConnectionError,
|
||||
cast(Any, openai).APIError,
|
||||
cast(Any, httpx).RemoteProtocolError,
|
||||
cast(Any, httpx).ReadTimeout,
|
||||
# TODO: replace with comparable OpenAI 1+ error
|
||||
)
|
||||
|
||||
|
||||
@ -31,8 +31,14 @@ class SearchResult:
|
||||
# actual text strings that are in the context window, built from context_data
|
||||
context_text: str | list[str] | dict[str, str]
|
||||
completion_time: float
|
||||
# total LLM calls and token usage
|
||||
llm_calls: int
|
||||
prompt_tokens: int
|
||||
output_tokens: int
|
||||
# breakdown of LLM calls and token usage
|
||||
llm_calls_categories: dict[str, int] | None = None
|
||||
prompt_tokens_categories: dict[str, int] | None = None
|
||||
output_tokens_categories: dict[str, int] | None = None
|
||||
|
||||
|
||||
T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder)
|
||||
|
||||
@ -7,8 +7,6 @@ import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphrag.query.llm.text_utils import num_tokens
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -39,7 +37,11 @@ class DriftAction:
|
||||
self.follow_ups: list[DriftAction] = (
|
||||
follow_ups if follow_ups is not None else []
|
||||
)
|
||||
self.metadata: dict[str, Any] = {}
|
||||
self.metadata: dict[str, Any] = {
|
||||
"llm_calls": 0,
|
||||
"prompt_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
}
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
@ -85,12 +87,10 @@ class DriftAction:
|
||||
|
||||
if self.answer is None:
|
||||
log.warning("No answer found for query: %s", self.query)
|
||||
generated_tokens = 0
|
||||
else:
|
||||
generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
|
||||
self.metadata.update({
|
||||
"token_ct": search_result.prompt_tokens + generated_tokens
|
||||
})
|
||||
|
||||
self.metadata["llm_calls"] += 1
|
||||
self.metadata["prompt_tokens"] += search_result.prompt_tokens
|
||||
self.metadata["output_tokens"] += search_result.output_tokens
|
||||
|
||||
self.follow_ups = response.pop("follow_up_queries", [])
|
||||
if not self.follow_ups:
|
||||
|
||||
@ -69,7 +69,6 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
self.covariates = covariates
|
||||
self.embedding_vectorstore_key = embedding_vectorstore_key
|
||||
|
||||
self.llm_tokens = 0
|
||||
self.local_mixed_context = (
|
||||
local_mixed_context or self.init_local_context_builder()
|
||||
)
|
||||
@ -160,7 +159,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
and isinstance(query_embedding[0], type(embedding[0]))
|
||||
)
|
||||
|
||||
def build_context(self, query: str, **kwargs) -> pd.DataFrame:
|
||||
def build_context(
|
||||
self, query: str, **kwargs
|
||||
) -> tuple[pd.DataFrame, dict[str, int]]:
|
||||
"""
|
||||
Build DRIFT search context.
|
||||
|
||||
@ -172,6 +173,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame: Top-k most similar documents.
|
||||
dict[str, int]: Number of LLM calls, prompt tokens, and output tokens.
|
||||
Raises
|
||||
------
|
||||
@ -192,7 +194,6 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
)
|
||||
|
||||
query_embedding, token_ct = query_processor(query)
|
||||
self.llm_tokens += token_ct
|
||||
|
||||
report_df = self.convert_reports_to_df(self.reports)
|
||||
|
||||
@ -219,4 +220,4 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
|
||||
# Sort by similarity and select top-k
|
||||
top_k = report_df.nlargest(self.config.drift_k_followups, "similarity")
|
||||
|
||||
return top_k.loc[:, ["short_id", "community_id", "full_content"]]
|
||||
return top_k.loc[:, ["short_id", "community_id", "full_content"]], token_ct
|
||||
|
||||
@ -50,7 +50,7 @@ class PrimerQueryProcessor:
|
||||
self.token_encoder = token_encoder
|
||||
self.reports = reports
|
||||
|
||||
def expand_query(self, query: str) -> tuple[str, int]:
|
||||
def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
|
||||
"""
|
||||
Expand the query using a random community report template.
|
||||
|
||||
@ -59,9 +59,8 @@ class PrimerQueryProcessor:
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[str, int]: Expanded query text and the number of tokens used.
|
||||
tuple[str, dict[str, int]]: Expanded query text and the number of tokens used.
|
||||
"""
|
||||
token_ct = 0
|
||||
template = secrets.choice(self.reports).full_content # nosec S311
|
||||
|
||||
prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
|
||||
@ -72,13 +71,19 @@ class PrimerQueryProcessor:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
text = self.chat_llm.generate(messages)
|
||||
token_ct = num_tokens(text + query)
|
||||
prompt_tokens = num_tokens(prompt, self.token_encoder)
|
||||
output_tokens = num_tokens(text, self.token_encoder)
|
||||
token_ct = {
|
||||
"llm_calls": 1,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
}
|
||||
if text == "":
|
||||
log.warning("Failed to generate expansion for query: %s", query)
|
||||
return query, token_ct
|
||||
return text, token_ct
|
||||
|
||||
def __call__(self, query: str) -> tuple[list[float], int]:
|
||||
def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
|
||||
"""
|
||||
Call method to process the query, expand it, and embed the result.
|
||||
|
||||
@ -117,7 +122,7 @@ class DRIFTPrimer:
|
||||
|
||||
async def decompose_query(
|
||||
self, query: str, reports: pd.DataFrame
|
||||
) -> tuple[dict, int]:
|
||||
) -> tuple[dict, dict[str, int]]:
|
||||
"""
|
||||
Decompose the query into subqueries based on the fetched global structures.
|
||||
|
||||
@ -127,7 +132,7 @@ class DRIFTPrimer:
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[dict, int]: Parsed response and the number of tokens used.
|
||||
tuple[dict, dict[str, int]]: Parsed response and the number of prompt and output tokens used.
"""
|
||||
community_reports = "\n\n".join(reports["full_content"].tolist())
|
||||
prompt = DRIFT_PRIMER_PROMPT.format(
|
||||
@ -140,7 +145,12 @@ class DRIFTPrimer:
|
||||
)
|
||||
|
||||
parsed_response = json.loads(response)
|
||||
token_ct = num_tokens(prompt + response, self.token_encoder)
|
||||
|
||||
token_ct = {
|
||||
"llm_calls": 1,
|
||||
"prompt_tokens": num_tokens(prompt, self.token_encoder),
|
||||
"output_tokens": num_tokens(response, self.token_encoder),
|
||||
}
|
||||
|
||||
return parsed_response, token_ct
|
||||
|
||||
@ -163,7 +173,7 @@ class DRIFTPrimer:
|
||||
start_time = time.perf_counter()
|
||||
report_folds = self.split_reports(top_k_reports)
|
||||
tasks = [self.decompose_query(query, fold) for fold in report_folds]
|
||||
results_with_tokens = await tqdm_asyncio.gather(*tasks)
|
||||
results_with_tokens = await tqdm_asyncio.gather(*tasks, leave=False)
|
||||
|
||||
completion_time = time.perf_counter() - start_time
|
||||
|
||||
@ -173,7 +183,8 @@ class DRIFTPrimer:
|
||||
context_text=top_k_reports.to_json() or "",
|
||||
completion_time=completion_time,
|
||||
llm_calls=len(results_with_tokens),
|
||||
prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
|
||||
prompt_tokens=sum(ct["prompt_tokens"] for _, ct in results_with_tokens),
|
||||
output_tokens=sum(ct["output_tokens"] for _, ct in results_with_tokens),
|
||||
)
|
||||
|
||||
def split_reports(self, reports: pd.DataFrame) -> list[pd.DataFrame]:
|
||||
|
||||
@ -163,7 +163,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
action.asearch(search_engine=search_engine, global_query=global_query)
|
||||
for action in actions
|
||||
]
|
||||
return await tqdm_asyncio.gather(*tasks)
|
||||
return await tqdm_asyncio.gather(*tasks, leave=False)
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
@ -190,20 +190,25 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
error_msg = "DRIFT Search query cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
|
||||
|
||||
start_time = time.perf_counter()
|
||||
primer_token_ct = 0
|
||||
context_token_ct = 0
|
||||
|
||||
# Check if query state is empty
|
||||
if not self.query_state.graph:
|
||||
# Prime the search with the primer
|
||||
primer_context = self.context_builder.build_context(query)
|
||||
context_token_ct = self.context_builder.llm_tokens
|
||||
primer_context, token_ct = self.context_builder.build_context(query)
|
||||
llm_calls["build_context"] = token_ct["llm_calls"]
|
||||
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
|
||||
output_tokens["build_context"] = token_ct["prompt_tokens"]
|
||||
|
||||
primer_response = await self.primer.asearch(
|
||||
query=query, top_k_reports=primer_context
|
||||
)
|
||||
primer_token_ct = primer_response.prompt_tokens
|
||||
llm_calls["primer"] = primer_response.llm_calls
|
||||
prompt_tokens["primer"] = primer_response.prompt_tokens
|
||||
output_tokens["primer"] = primer_response.output_tokens
|
||||
|
||||
# Package response into DriftAction
|
||||
init_action = self._process_primer_results(query, primer_response)
|
||||
self.query_state.add_action(init_action)
|
||||
@ -233,9 +238,10 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
t_elapsed = time.perf_counter() - start_time
|
||||
|
||||
# Calculate token usage
|
||||
total_tokens = (
|
||||
primer_token_ct + context_token_ct + self.query_state.action_token_ct()
|
||||
)
|
||||
token_ct = self.query_state.action_token_ct()
|
||||
llm_calls["action"] = token_ct["llm_calls"]
|
||||
prompt_tokens["action"] = token_ct["prompt_tokens"]
|
||||
output_tokens["action"] = token_ct["output_tokens"]
|
||||
|
||||
# Package up context data
|
||||
response_state, context_data, context_text = self.query_state.serialize(
|
||||
@ -247,10 +253,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
|
||||
context_data=context_data,
|
||||
context_text=context_text,
|
||||
completion_time=t_elapsed,
|
||||
llm_calls=1
|
||||
+ self.config.primer_folds
|
||||
+ (self.config.drift_k_followups - llm_call_offset) * self.config.n_depth,
|
||||
prompt_tokens=total_tokens,
|
||||
llm_calls=sum(llm_calls.values()),
|
||||
prompt_tokens=sum(prompt_tokens.values()),
|
||||
output_tokens=sum(output_tokens.values()),
|
||||
llm_calls_categories=llm_calls,
|
||||
prompt_tokens_categories=prompt_tokens,
|
||||
output_tokens_categories=output_tokens,
|
||||
)
|
||||
|
||||
def search(
|
||||
|
||||
@ -136,6 +136,15 @@ class QueryState:
|
||||
if source_action and target_action:
|
||||
self.relate_actions(source_action, target_action, weight)
|
||||
|
||||
def action_token_ct(self) -> int:
|
||||
def action_token_ct(self) -> dict[str, int]:
|
||||
"""Return the token count of the action."""
|
||||
return sum(action.metadata.get("token_ct", 0) for action in self.graph.nodes)
|
||||
llm_calls, prompt_tokens, output_tokens = 0, 0, 0
|
||||
for action in self.graph.nodes:
|
||||
llm_calls += action.metadata["llm_calls"]
|
||||
prompt_tokens += action.metadata["prompt_tokens"]
|
||||
output_tokens += action.metadata["output_tokens"]
|
||||
return {
|
||||
"llm_calls": llm_calls,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
}
|
||||
|
||||
@ -5,16 +5,19 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import tiktoken
|
||||
|
||||
from graphrag.model import CommunityReport, Entity
|
||||
from graphrag.model import Community, CommunityReport, Entity
|
||||
from graphrag.query.context_builder.builders import ContextBuilderResult
|
||||
from graphrag.query.context_builder.community_context import (
|
||||
build_community_context,
|
||||
)
|
||||
from graphrag.query.context_builder.conversation_history import (
|
||||
ConversationHistory,
|
||||
)
|
||||
from graphrag.query.context_builder.dynamic_community_selection import (
|
||||
DynamicCommunitySelection,
|
||||
)
|
||||
from graphrag.query.structured_search.base import GlobalContextBuilder
|
||||
|
||||
|
||||
@ -24,17 +27,32 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
community_reports: list[CommunityReport],
|
||||
communities: list[Community],
|
||||
entities: list[Entity] | None = None,
|
||||
token_encoder: tiktoken.Encoding | None = None,
|
||||
dynamic_community_selection: bool = False,
|
||||
dynamic_community_selection_kwargs: dict[str, Any] | None = None,
|
||||
random_state: int = 86,
|
||||
):
|
||||
self.community_reports = community_reports
|
||||
self.entities = entities
|
||||
self.token_encoder = token_encoder
|
||||
self.dynamic_community_selection = None
|
||||
if dynamic_community_selection and isinstance(
|
||||
dynamic_community_selection_kwargs, dict
|
||||
):
|
||||
self.dynamic_community_selection = DynamicCommunitySelection(
|
||||
community_reports=community_reports,
|
||||
communities=communities,
|
||||
llm=dynamic_community_selection_kwargs.pop("llm"),
|
||||
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
|
||||
**dynamic_community_selection_kwargs,
|
||||
)
|
||||
self.random_state = random_state
|
||||
|
||||
def build_context(
|
||||
async def build_context(
|
||||
self,
|
||||
query: str,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
use_community_summary: bool = True,
|
||||
column_delimiter: str = "|",
|
||||
@ -50,10 +68,11 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
conversation_history_user_turns_only: bool = True,
|
||||
conversation_history_max_turns: int | None = 5,
|
||||
**kwargs: Any,
|
||||
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
|
||||
) -> ContextBuilderResult:
|
||||
"""Prepare batches of community report data table as context data for global search."""
|
||||
conversation_history_context = ""
|
||||
final_context_data = {}
|
||||
llm_calls, prompt_tokens, output_tokens = 0, 0, 0
|
||||
if conversation_history:
|
||||
# build conversation history context
|
||||
(
|
||||
@ -69,8 +88,18 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
if conversation_history_context != "":
|
||||
final_context_data = conversation_history_context_data
|
||||
|
||||
community_reports = self.community_reports
|
||||
if self.dynamic_community_selection is not None:
|
||||
(
|
||||
community_reports,
|
||||
dynamic_info,
|
||||
) = await self.dynamic_community_selection.select(query)
|
||||
llm_calls += dynamic_info["llm_calls"]
|
||||
prompt_tokens += dynamic_info["prompt_tokens"]
|
||||
output_tokens += dynamic_info["output_tokens"]
|
||||
|
||||
community_context, community_context_data = build_community_context(
|
||||
community_reports=self.community_reports,
|
||||
community_reports=community_reports,
|
||||
entities=self.entities,
|
||||
token_encoder=self.token_encoder,
|
||||
use_community_summary=use_community_summary,
|
||||
@ -104,4 +133,10 @@ class GlobalCommunityContext(GlobalContextBuilder):
|
||||
# Update the final context data with the provided community_context_data
|
||||
final_context_data.update(community_context_data)
|
||||
|
||||
return final_context, final_context_data
|
||||
return ContextBuilderResult(
|
||||
context_chunks=final_context,
|
||||
context_records=final_context_data,
|
||||
llm_calls=llm_calls,
|
||||
prompt_tokens=prompt_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
||||
@ -45,7 +45,7 @@ DEFAULT_REDUCE_LLM_PARAMS = {
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class GlobalSearchResult(SearchResult):
|
||||
"""A GlobalSearch result."""
|
||||
|
||||
@ -105,24 +105,25 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
) -> AsyncGenerator:
|
||||
"""Stream the global search response."""
|
||||
context_chunks, context_records = self.context_builder.build_context(
|
||||
conversation_history=conversation_history, **self.context_builder_params
|
||||
context_result = await self.context_builder.build_context(
|
||||
conversation_history=conversation_history,
|
||||
**self.context_builder_params,
|
||||
)
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_map_response_start(context_chunks) # type: ignore
|
||||
callback.on_map_response_start(context_result.context_chunks) # type: ignore
|
||||
map_responses = await asyncio.gather(*[
|
||||
self._map_response_single_batch(
|
||||
context_data=data, query=query, **self.map_llm_params
|
||||
)
|
||||
for data in context_chunks
|
||||
for data in context_result.context_chunks
|
||||
])
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_map_response_end(map_responses) # type: ignore
|
||||
|
||||
# send context records first before sending the reduce response
|
||||
yield context_records
|
||||
yield context_result.context_records
|
||||
async for response in self._stream_reduce_response(
|
||||
map_responses=map_responses, # type: ignore
|
||||
query=query,
|
||||
@ -145,26 +146,33 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
|
||||
- Step 2: Combine the answers from step 1 to generate the final answer
"""
|
||||
# Step 1: Generate answers for each batch of community short summaries
|
||||
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
|
||||
|
||||
start_time = time.time()
|
||||
context_chunks, context_records = self.context_builder.build_context(
|
||||
conversation_history=conversation_history, **self.context_builder_params
|
||||
context_result = await self.context_builder.build_context(
|
||||
query=query,
|
||||
conversation_history=conversation_history,
|
||||
**self.context_builder_params,
|
||||
)
|
||||
llm_calls["build_context"] = context_result.llm_calls
|
||||
prompt_tokens["build_context"] = context_result.prompt_tokens
|
||||
output_tokens["build_context"] = context_result.output_tokens
|
||||
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_map_response_start(context_chunks) # type: ignore
|
||||
callback.on_map_response_start(context_result.context_chunks) # type: ignore
|
||||
map_responses = await asyncio.gather(*[
|
||||
self._map_response_single_batch(
|
||||
context_data=data, query=query, **self.map_llm_params
|
||||
)
|
||||
for data in context_chunks
|
||||
for data in context_result.context_chunks
|
||||
])
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_map_response_end(map_responses)
|
||||
map_llm_calls = sum(response.llm_calls for response in map_responses)
|
||||
map_prompt_tokens = sum(response.prompt_tokens for response in map_responses)
|
||||
llm_calls["map"] = sum(response.llm_calls for response in map_responses)
|
||||
prompt_tokens["map"] = sum(response.prompt_tokens for response in map_responses)
|
||||
output_tokens["map"] = sum(response.output_tokens for response in map_responses)
|
||||
|
||||
# Step 2: Combine the intermediate answers from step 1 to generate the final answer
reduce_response = await self._reduce_response(
|
||||
@@ -172,17 +180,24 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
            query=query,
            **self.reduce_llm_params,
        )
+        llm_calls["reduce"] = reduce_response.llm_calls
+        prompt_tokens["reduce"] = reduce_response.prompt_tokens
+        output_tokens["reduce"] = reduce_response.output_tokens

        return GlobalSearchResult(
            response=reduce_response.response,
-            context_data=context_records,
-            context_text=context_chunks,
+            context_data=context_result.context_records,
+            context_text=context_result.context_chunks,
            map_responses=map_responses,
            reduce_context_data=reduce_response.context_data,
            reduce_context_text=reduce_response.context_text,
            completion_time=time.time() - start_time,
-            llm_calls=map_llm_calls + reduce_response.llm_calls,
-            prompt_tokens=map_prompt_tokens + reduce_response.prompt_tokens,
+            llm_calls=sum(llm_calls.values()),
+            prompt_tokens=sum(prompt_tokens.values()),
+            output_tokens=sum(output_tokens.values()),
+            llm_calls_categories=llm_calls,
+            prompt_tokens_categories=prompt_tokens,
+            output_tokens_categories=output_tokens,
        )

    def search(
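The hunks above swap the flat map/reduce counters for per-stage dictionaries that are summed into the totals and also surfaced verbatim through the new *_categories fields. A minimal standalone sketch of that accounting pattern (stage names match the hunks; the counts are illustrative, not taken from the repository):

# Sketch of the per-stage accounting introduced above; values are made up.
llm_calls: dict[str, int] = {}
prompt_tokens: dict[str, int] = {}
output_tokens: dict[str, int] = {}

llm_calls["build_context"] = 3    # e.g. dynamic community-selection rating calls
prompt_tokens["build_context"] = 1200
output_tokens["build_context"] = 150

llm_calls["map"] = 8              # one call per community-report batch
prompt_tokens["map"] = 9600
output_tokens["map"] = 2400

llm_calls["reduce"] = 1           # final answer synthesis
prompt_tokens["reduce"] = 3100
output_tokens["reduce"] = 500

# Totals are plain sums; the dicts themselves become the
# llm_calls_categories / prompt_tokens_categories / output_tokens_categories fields.
total_calls = sum(llm_calls.values())        # 12
total_prompt = sum(prompt_tokens.values())   # 13900
total_output = sum(output_tokens.values())   # 3050
print(total_calls, total_prompt, total_output)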
@@ -230,6 +245,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=num_tokens(search_response, self.token_encoder),
            )

        except Exception:
@@ -241,6 +257,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=0,
            )

    def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
@@ -319,6 +336,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                completion_time=time.time() - start_time,
                llm_calls=0,
                prompt_tokens=0,
+                output_tokens=0,
            )

        filtered_key_points = sorted(
@@ -372,6 +390,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=num_tokens(search_response, self.token_encoder),
            )
        except Exception:
            log.exception("Exception in reduce_response")
@@ -382,6 +401,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=0,
            )

    async def _stream_reduce_response(
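Each SearchResult construction above also gains an output_tokens count, produced by running the response text through the same num_tokens helper used for prompts. A rough standalone equivalent, assuming a cl100k_base tiktoken encoding (the real helper may differ in details):

import tiktoken

def count_tokens(text: str, encoder: tiktoken.Encoding) -> int:
    # A token count is just the length of the encoded sequence.
    return len(encoder.encode(text))

encoder = tiktoken.get_encoding("cl100k_base")
print(count_tokens("An example search response.", encoder))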
@@ -16,6 +16,7 @@ from graphrag.model import (
    Relationship,
    TextUnit,
)
+from graphrag.query.context_builder.builders import ContextBuilderResult
from graphrag.query.context_builder.community_context import (
    build_community_context,
)
@@ -113,7 +114,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
        community_context_name: str = "Reports",
        column_delimiter: str = "|",
        **kwargs: dict[str, Any],
-    ) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
+    ) -> ContextBuilderResult:
        """
        Build data context for local search prompt.

@@ -217,7 +218,10 @@ class LocalSearchMixedContext(LocalContextBuilder):
            final_context.append(text_unit_context)
            final_context_data = {**final_context_data, **text_unit_context_data}

-        return ("\n\n".join(final_context), final_context_data)
+        return ContextBuilderResult(
+            context_chunks="\n\n".join(final_context),
+            context_records=final_context_data,
+        )

    def _build_community_context(
        self,
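With this change, LocalSearchMixedContext.build_context returns a ContextBuilderResult instead of a (context_text, context_records) tuple, which lets context building report its own LLM usage, relevant once dynamic community selection issues rating calls while assembling context. A hedged sketch of what such a result type could look like, inferred purely from the fields the surrounding hunks read (the real definition lives in graphrag.query.context_builder.builders and may differ):

from dataclasses import dataclass

import pandas as pd

@dataclass
class ContextBuilderResult:
    """Sketch only: fields inferred from how the hunks above use the result."""
    context_chunks: str | list[str]           # formatted context text for the prompt
    context_records: dict[str, pd.DataFrame]  # the raw records behind that text
    llm_calls: int = 0       # LLM calls made while building context
    prompt_tokens: int = 0   # prompt tokens spent while building context
    output_tokens: int = 0   # output tokens produced while building context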
@@ -63,25 +63,30 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
        """Build local search context that fits a single context window and generate answer for the user query."""
        start_time = time.time()
        search_prompt = ""

-        context_text, context_records = self.context_builder.build_context(
+        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
+        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **kwargs,
            **self.context_builder_params,
        )
+        llm_calls["build_context"] = context_result.llm_calls
+        prompt_tokens["build_context"] = context_result.prompt_tokens
+        output_tokens["build_context"] = context_result.output_tokens

        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        try:
            if "drift_query" in kwargs:
                drift_query = kwargs["drift_query"]
                search_prompt = self.system_prompt.format(
-                    context_data=context_text,
+                    context_data=context_result.context_chunks,
                    response_type=self.response_type,
                    global_query=drift_query,
                )
            else:
                search_prompt = self.system_prompt.format(
-                    context_data=context_text, response_type=self.response_type
+                    context_data=context_result.context_chunks,
+                    response_type=self.response_type,
                )
            search_messages = [
                {"role": "system", "content": search_prompt},
@@ -94,25 +99,33 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
                callbacks=self.callbacks,
                **self.llm_params,
            )
+            llm_calls["response"] = 1
+            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
+            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
-                context_data=context_records,
-                context_text=context_text,
+                context_data=context_result.context_records,
+                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                llm_calls=sum(llm_calls.values()),
+                prompt_tokens=sum(prompt_tokens.values()),
+                output_tokens=sum(output_tokens.values()),
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
            )

        except Exception:
            log.exception("Exception in _asearch")
            return SearchResult(
                response="",
-                context_data=context_records,
-                context_text=context_text,
+                context_data=context_result.context_records,
+                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=0,
            )

    async def astream_search(
@@ -123,14 +136,14 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
        """Build local search context that fits a single context window and generate answer for the user query."""
        start_time = time.time()

-        context_text, context_records = self.context_builder.build_context(
+        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **self.context_builder_params,
        )
        log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
        search_prompt = self.system_prompt.format(
-            context_data=context_text, response_type=self.response_type
+            context_data=context_result.context_chunks, response_type=self.response_type
        )
        search_messages = [
            {"role": "system", "content": search_prompt},
@@ -138,7 +151,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
        ]

        # send context records first before sending the reduce response
-        yield context_records
+        yield context_result.context_records
        async for response in self.llm.astream_generate(  # type: ignore
            messages=search_messages,
            callbacks=self.callbacks,
@@ -155,16 +168,22 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
        """Build local search context that fits a single context window and generate answer for the user question."""
        start_time = time.time()
        search_prompt = ""
-        context_text, context_records = self.context_builder.build_context(
+        llm_calls, prompt_tokens, output_tokens = {}, {}, {}
+        context_result = self.context_builder.build_context(
            query=query,
            conversation_history=conversation_history,
            **kwargs,
            **self.context_builder_params,
        )
+        llm_calls["build_context"] = context_result.llm_calls
+        prompt_tokens["build_context"] = context_result.prompt_tokens
+        output_tokens["build_context"] = context_result.output_tokens

        log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
        try:
            search_prompt = self.system_prompt.format(
-                context_data=context_text, response_type=self.response_type
+                context_data=context_result.context_chunks,
+                response_type=self.response_type,
            )
            search_messages = [
                {"role": "system", "content": search_prompt},
@@ -177,23 +196,31 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
                callbacks=self.callbacks,
                **self.llm_params,
            )
+            llm_calls["response"] = 1
+            prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
+            output_tokens["response"] = num_tokens(response, self.token_encoder)

            return SearchResult(
                response=response,
-                context_data=context_records,
-                context_text=context_text,
+                context_data=context_result.context_records,
+                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
-                llm_calls=1,
-                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                llm_calls=sum(llm_calls.values()),
+                prompt_tokens=sum(prompt_tokens.values()),
+                output_tokens=sum(output_tokens.values()),
+                llm_calls_categories=llm_calls,
+                prompt_tokens_categories=prompt_tokens,
+                output_tokens_categories=output_tokens,
            )

        except Exception:
            log.exception("Exception in _map_response_single_batch")
            return SearchResult(
                response="",
-                context_data=context_records,
-                context_text=context_text,
+                context_data=context_result.context_records,
+                context_text=context_result.context_chunks,
                completion_time=time.time() - start_time,
                llm_calls=1,
                prompt_tokens=num_tokens(search_prompt, self.token_encoder),
+                output_tokens=0,
            )
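Taken together, callers of LocalSearch.search or GlobalSearch.asearch can now break usage down per stage through the new *_categories fields. A hypothetical post-search snippet (the engine setup is elided and the demo object is a stand-in; only the field names are taken from the hunks above):

from types import SimpleNamespace

def report_usage(result) -> None:
    """Print the per-stage usage breakdown carried by the *_categories fields added above."""
    print(f"completion time: {result.completion_time:.2f}s")
    for category, calls in result.llm_calls_categories.items():
        prompt = result.prompt_tokens_categories[category]
        output = result.output_tokens_categories[category]
        print(f"  {category}: {calls} call(s), {prompt} prompt tokens, {output} output tokens")

# Demo with stand-in numbers; a real call would pass the result of
# search_engine.search(...) or await search_engine.asearch(...).
demo = SimpleNamespace(
    completion_time=4.2,
    llm_calls_categories={"build_context": 3, "response": 1},
    prompt_tokens_categories={"build_context": 1200, "response": 3100},
    output_tokens_categories={"build_context": 150, "response": 500},
)
report_usage(demo)

In the dynamic-selection workflow, the build_context category is presumably where the extra community-rating calls surface, which is exactly the visibility this breakdown is meant to add.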