Implement dynamic community selection for global search (#1396)

* update gitignore

* add dynamic community selection to updated main branch

* update SearchResult to record output_tokens.

* update search result

* dynamic search working

* format

* add llm_calls_categories, prompt_tokens_categories, and output_tokens_categories

* update

* formatting

* log drift search output and prompt tokens separately

* update global_search.ipynb. update operation dulce dataset and add create_final_communities. update dynamic community selection init

* add .ipynb back to cspell.config.yaml

* format

* add notebook example on dynamic search

* rearrange

* update gitignore

* format code

* code format

* code format

* fix default variable

---------

Co-authored-by: Bryan Li <bryanlimy@gmail.com>
Alonso Guevara 2024-11-11 18:45:07 -06:00 committed by GitHub
parent ba50caab4d
commit e53422366d
37 changed files with 1873 additions and 210 deletions

.gitignore
View File

@ -50,3 +50,9 @@ site/
docsite/
.yarn/
.pnp*
# PyCharm
.idea/
# Jupyter notebook
.ipynb_checkpoints/

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Implement dynamic community selection to global search"
}

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -21,7 +21,11 @@
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.query.indexer_adapters import read_indexer_entities, read_indexer_reports\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_communities,\n",
" read_indexer_entities,\n",
" read_indexer_reports,\n",
")\n",
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
"from graphrag.query.structured_search.global_search.community_context import (\n",
@ -62,7 +66,7 @@
" max_retries=20,\n",
")\n",
"\n",
"token_encoder = tiktoken.get_encoding(\"cl100k_base\")"
"token_encoder = tiktoken.encoding_for_model(llm_model)"
]
},
{
@ -72,17 +76,19 @@
"### Load community reports as context for global search\n",
"\n",
"- Load all community reports in the `create_final_community_reports` table from the ire-indexing engine, to be used as context data for global search.\n",
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the `rank` attribute in the community reports table for context ranking)"
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the rank attribute in the community reports table for context ranking)\n",
"- Load all communities in the `create_final_communites` table from the ire-indexing engine, to be used to reconstruct the community graph hierarchy for dynamic community selection."
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# parquet files generated from indexing pipeline\n",
"INPUT_DIR = \"./inputs/operation dulce\"\n",
"COMMUNITY_TABLE = \"create_final_communities\"\n",
"COMMUNITY_REPORT_TABLE = \"create_final_community_reports\"\n",
"ENTITY_TABLE = \"create_final_nodes\"\n",
"ENTITY_EMBEDDING_TABLE = \"create_final_entities\"\n",
@ -94,20 +100,191 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total report count: 20\n",
"Report count after filtering by community level 2: 17\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>community</th>\n",
" <th>full_content</th>\n",
" <th>level</th>\n",
" <th>rank</th>\n",
" <th>title</th>\n",
" <th>rank_explanation</th>\n",
" <th>summary</th>\n",
" <th>findings</th>\n",
" <th>full_content_json</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The Paranormal Military Squad, stationed at Du...</td>\n",
" <td>[{'explanation': 'Jordan is a central figure i...</td>\n",
" <td>{\\n \"title\": \"Paranormal Military Squad at ...</td>\n",
" <td>1ba2d200-dd26-4693-affe-a5539d0a0e0d</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The community centers around Dulce, a secretiv...</td>\n",
" <td>[{'explanation': 'Dulce is described as a top-...</td>\n",
" <td>{\\n \"title\": \"Dulce and Paranormal Military...</td>\n",
" <td>a8a530b0-ae6b-44ea-b11c-9f70d138298d</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>12</td>\n",
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
" <td>1</td>\n",
" <td>7.5</td>\n",
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
" <td>The impact severity rating is relatively high ...</td>\n",
" <td>The community centers around the Paranormal Mi...</td>\n",
" <td>[{'explanation': 'Taylor is a central figure w...</td>\n",
" <td>{\\n \"title\": \"Paranormal Military Squad and...</td>\n",
" <td>0478975b-c805-4cc1-b746-82f3e689e2f3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>13</td>\n",
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
" <td>1</td>\n",
" <td>7.5</td>\n",
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
" <td>The impact severity rating is relatively high ...</td>\n",
" <td>This report explores the intricate dynamics of...</td>\n",
" <td>[{'explanation': 'Cruz is a central figure in ...</td>\n",
" <td>{\\n \"title\": \"Mission Dynamics and Leadersh...</td>\n",
" <td>b56f6e68-3951-4f07-8760-63700944a375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>14</td>\n",
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The community centers around the Dulce Base, a...</td>\n",
" <td>[{'explanation': 'Sam Rivera, a member of the ...</td>\n",
" <td>{\\n \"title\": \"Dulce Base and Paranormal Mil...</td>\n",
" <td>736e7006-d050-4abb-a122-00febf3f540f</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" community full_content level rank \\\n",
"0 10 # Paranormal Military Squad at Dulce Base: Dec... 1 8.5 \n",
"1 11 # Dulce and Paranormal Military Squad Operatio... 1 8.5 \n",
"2 12 # Paranormal Military Squad and Dulce Base Ope... 1 7.5 \n",
"3 13 # Mission Dynamics and Leadership: Cruz and Wa... 1 7.5 \n",
"4 14 # Dulce Base and Paranormal Military Squad: Br... 1 8.5 \n",
"\n",
" title \\\n",
"0 Paranormal Military Squad at Dulce Base: Decod... \n",
"1 Dulce and Paranormal Military Squad Operations \n",
"2 Paranormal Military Squad and Dulce Base Opera... \n",
"3 Mission Dynamics and Leadership: Cruz and Wash... \n",
"4 Dulce Base and Paranormal Military Squad: Brid... \n",
"\n",
" rank_explanation \\\n",
"0 The impact severity rating is high due to the ... \n",
"1 The impact severity rating is high due to the ... \n",
"2 The impact severity rating is relatively high ... \n",
"3 The impact severity rating is relatively high ... \n",
"4 The impact severity rating is high due to the ... \n",
"\n",
" summary \\\n",
"0 The Paranormal Military Squad, stationed at Du... \n",
"1 The community centers around Dulce, a secretiv... \n",
"2 The community centers around the Paranormal Mi... \n",
"3 This report explores the intricate dynamics of... \n",
"4 The community centers around the Dulce Base, a... \n",
"\n",
" findings \\\n",
"0 [{'explanation': 'Jordan is a central figure i... \n",
"1 [{'explanation': 'Dulce is described as a top-... \n",
"2 [{'explanation': 'Taylor is a central figure w... \n",
"3 [{'explanation': 'Cruz is a central figure in ... \n",
"4 [{'explanation': 'Sam Rivera, a member of the ... \n",
"\n",
" full_content_json \\\n",
"0 {\\n \"title\": \"Paranormal Military Squad at ... \n",
"1 {\\n \"title\": \"Dulce and Paranormal Military... \n",
"2 {\\n \"title\": \"Paranormal Military Squad and... \n",
"3 {\\n \"title\": \"Mission Dynamics and Leadersh... \n",
"4 {\\n \"title\": \"Dulce Base and Paranormal Mil... \n",
"\n",
" id \n",
"0 1ba2d200-dd26-4693-affe-a5539d0a0e0d \n",
"1 a8a530b0-ae6b-44ea-b11c-9f70d138298d \n",
"2 0478975b-c805-4cc1-b746-82f3e689e2f3 \n",
"3 b56f6e68-3951-4f07-8760-63700944a375 \n",
"4 736e7006-d050-4abb-a122-00febf3f540f "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
"entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
"report_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet\")\n",
"entity_embedding_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet\")\n",
"\n",
"communities = read_indexer_communities(community_df, entity_df, report_df)\n",
"reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)\n",
"entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)\n",
"\n",
"print(f\"Total report count: {len(report_df)}\")\n",
"print(\n",
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
")\n",
"\n",
"report_df.head()"
]
},
@ -120,12 +297,13 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"context_builder = GlobalCommunityContext(\n",
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
")"
@ -140,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -171,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -192,12 +370,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"### Overview of Cosmic Vocalization\n",
"\n",
"Cosmic Vocalization is a phenomenon that has garnered significant attention from various individuals and groups. It is perceived as a cosmic event with potential implications for security and interstellar communication. The Paranormal Military Squad is actively engaged with Cosmic Vocalization, indicating its strategic importance in security measures [Data: Reports (6)].\n",
"\n",
"### Key Perspectives and Concerns\n",
"\n",
"1. **Strategic Engagement**: The Paranormal Military Squad's involvement suggests that Cosmic Vocalization is not only a subject of interest but also a matter of strategic importance. This engagement highlights the potential security implications of these cosmic phenomena [Data: Reports (6)].\n",
"\n",
"2. **Community Interest**: Within the community, Cosmic Vocalization is a focal point of interest. Alex Mercer, for instance, perceives it as part of an interstellar duet, which suggests a responsive and perhaps communicative approach to these cosmic events [Data: Reports (6)].\n",
"\n",
"3. **Potential Threats**: Concerns have been raised by individuals like Taylor Cruz, who fears that Cosmic Vocalization might be a homing tune. This perspective adds a layer of urgency and suggests that there may be potential threats associated with these cosmic sounds [Data: Reports (6)].\n",
"\n",
"### Metaphorical Interpretation\n",
"\n",
"The Universe is metaphorically treated as a concert hall by the Paranormal Military Squad, which suggests a broader perspective on how cosmic events are interpreted and responded to by human entities. This metaphorical view may influence how strategies and responses are formulated in relation to Cosmic Vocalization [Data: Reports (6)].\n",
"\n",
"In summary, Cosmic Vocalization is a complex phenomenon involving strategic, communicative, and potentially threatening elements. The involvement of the Paranormal Military Squad and the concerns raised by community members underscore its significance and the need for careful consideration of its implications.\n"
]
}
],
"source": [
"result = await search_engine.asearch(\n",
" \"What is the major conflict in this story and who are the protagonist and antagonist?\"\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n",
"\n",
"print(result.response)"
@ -205,9 +407,223 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>occurrence weight</th>\n",
" <th>content</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>15</td>\n",
" <td>Dulce Base and the Paranormal Military Squad: ...</td>\n",
" <td>1.00</td>\n",
" <td># Dulce Base and the Paranormal Military Squad...</td>\n",
" <td>9.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
" <td>0.30</td>\n",
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
" <td>0.30</td>\n",
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>7</td>\n",
" <td>Operation: Dulce and the Paranormal Military S...</td>\n",
" <td>0.20</td>\n",
" <td># Operation: Dulce and the Paranormal Military...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>8</td>\n",
" <td>Dr. Jordan Hayes and the Paranormal Military S...</td>\n",
" <td>0.18</td>\n",
" <td># Dr. Jordan Hayes and the Paranormal Military...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1</td>\n",
" <td>Earth's Interstellar Communication Initiative</td>\n",
" <td>0.16</td>\n",
" <td># Earth's Interstellar Communication Initiativ...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>12</td>\n",
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
" <td>0.16</td>\n",
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>13</td>\n",
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
" <td>0.16</td>\n",
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>14</td>\n",
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
" <td>0.12</td>\n",
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>16</td>\n",
" <td>Dulce Military Base and Alien Intelligence Com...</td>\n",
" <td>0.08</td>\n",
" <td># Dulce Military Base and Alien Intelligence C...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>18</td>\n",
" <td>Paranormal Military Squad Team and Dulce Base'...</td>\n",
" <td>0.04</td>\n",
" <td># Paranormal Military Squad Team and Dulce Bas...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>5</td>\n",
" <td>Alien Script and First Contact Operations</td>\n",
" <td>0.02</td>\n",
" <td># Alien Script and First Contact Operations\\n\\...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>4</td>\n",
" <td>Dulce Facility and Control Room of Dulce: Extr...</td>\n",
" <td>0.02</td>\n",
" <td># Dulce Facility and Control Room of Dulce: Ex...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>17</td>\n",
" <td>Dulce Team and Underground Command Center: Int...</td>\n",
" <td>0.02</td>\n",
" <td># Dulce Team and Underground Command Center: I...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>19</td>\n",
" <td>Central Terminal and Viewing Monitors at Dulce...</td>\n",
" <td>0.02</td>\n",
" <td># Central Terminal and Viewing Monitors at Dul...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>6</td>\n",
" <td>Cosmic Vocalization and Universe Interactions</td>\n",
" <td>0.02</td>\n",
" <td># Cosmic Vocalization and Universe Interaction...</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>9</td>\n",
" <td>Dulce Base Exploration by TEAM and MAINFRAME ROOM</td>\n",
" <td>0.02</td>\n",
" <td># Dulce Base Exploration by TEAM and MAINFRAME...</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id title occurrence weight \\\n",
"0 15 Dulce Base and the Paranormal Military Squad: ... 1.00 \n",
"1 11 Dulce and Paranormal Military Squad Operations 0.30 \n",
"2 10 Paranormal Military Squad at Dulce Base: Decod... 0.30 \n",
"3 7 Operation: Dulce and the Paranormal Military S... 0.20 \n",
"4 8 Dr. Jordan Hayes and the Paranormal Military S... 0.18 \n",
"5 1 Earth's Interstellar Communication Initiative 0.16 \n",
"6 12 Paranormal Military Squad and Dulce Base Opera... 0.16 \n",
"7 13 Mission Dynamics and Leadership: Cruz and Wash... 0.16 \n",
"8 14 Dulce Base and Paranormal Military Squad: Brid... 0.12 \n",
"9 16 Dulce Military Base and Alien Intelligence Com... 0.08 \n",
"10 18 Paranormal Military Squad Team and Dulce Base'... 0.04 \n",
"11 5 Alien Script and First Contact Operations 0.02 \n",
"12 4 Dulce Facility and Control Room of Dulce: Extr... 0.02 \n",
"13 17 Dulce Team and Underground Command Center: Int... 0.02 \n",
"14 19 Central Terminal and Viewing Monitors at Dulce... 0.02 \n",
"15 6 Cosmic Vocalization and Universe Interactions 0.02 \n",
"16 9 Dulce Base Exploration by TEAM and MAINFRAME ROOM 0.02 \n",
"\n",
" content rank \n",
"0 # Dulce Base and the Paranormal Military Squad... 9.5 \n",
"1 # Dulce and Paranormal Military Squad Operatio... 8.5 \n",
"2 # Paranormal Military Squad at Dulce Base: Dec... 8.5 \n",
"3 # Operation: Dulce and the Paranormal Military... 8.5 \n",
"4 # Dr. Jordan Hayes and the Paranormal Military... 8.5 \n",
"5 # Earth's Interstellar Communication Initiativ... 8.5 \n",
"6 # Paranormal Military Squad and Dulce Base Ope... 7.5 \n",
"7 # Mission Dynamics and Leadership: Cruz and Wa... 7.5 \n",
"8 # Dulce Base and Paranormal Military Squad: Br... 8.5 \n",
"9 # Dulce Military Base and Alien Intelligence C... 8.5 \n",
"10 # Paranormal Military Squad Team and Dulce Bas... 8.5 \n",
"11 # Alien Script and First Contact Operations\\n\\... 8.5 \n",
"12 # Dulce Facility and Control Room of Dulce: Ex... 8.5 \n",
"13 # Dulce Team and Underground Command Center: I... 8.5 \n",
"14 # Central Terminal and Viewing Monitors at Dul... 8.5 \n",
"15 # Cosmic Vocalization and Universe Interaction... 7.5 \n",
"16 # Dulce Base Exploration by TEAM and MAINFRAME... 7.5 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect the data used to build the context for the LLM responses\n",
"result.context_data[\"reports\"]"
@ -215,18 +631,28 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LLM calls: 2. Prompt tokens: 11292. Output tokens: 606.\n"
]
}
],
"source": [
"# inspect number of LLM calls and tokens\n",
"print(f\"LLM calls: {result.llm_calls}. LLM tokens: {result.prompt_tokens}\")"
"print(\n",
" f\"LLM calls: {result.llm_calls}. Prompt tokens: {result.prompt_tokens}. Output tokens: {result.output_tokens}.\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
@ -240,7 +666,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.5"
}
},
"nbformat": 4,

View File

@ -0,0 +1,615 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) 2024 Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import pandas as pd\n",
"import tiktoken\n",
"\n",
"from graphrag.query.indexer_adapters import (\n",
" read_indexer_communities,\n",
" read_indexer_entities,\n",
" read_indexer_reports,\n",
")\n",
"from graphrag.query.llm.oai.chat_openai import ChatOpenAI\n",
"from graphrag.query.llm.oai.typing import OpenaiApiType\n",
"from graphrag.query.structured_search.global_search.community_context import (\n",
" GlobalCommunityContext,\n",
")\n",
"from graphrag.query.structured_search.global_search.search import GlobalSearch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Global Search example\n",
"\n",
"Global search method generates answers by searching over all AI-generated community reports in a map-reduce fashion. This is a resource-intensive method, but often gives good responses for questions that require an understanding of the dataset as a whole (e.g. What are the most significant values of the herbs mentioned in this notebook?)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LLM setup"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"api_key = os.environ[\"GRAPHRAG_API_KEY\"]\n",
"llm_model = os.environ[\"GRAPHRAG_LLM_MODEL\"]\n",
"\n",
"llm = ChatOpenAI(\n",
" api_key=api_key,\n",
" model=llm_model,\n",
" api_type=OpenaiApiType.OpenAI, # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI\n",
" max_retries=20,\n",
")\n",
"\n",
"token_encoder = tiktoken.encoding_for_model(llm_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load community reports as context for global search\n",
"\n",
"- Load all community reports in the `create_final_community_reports` table from the ire-indexing engine, to be used as context data for global search.\n",
"- Load entities from the `create_final_nodes` and `create_final_entities` tables from the ire-indexing engine, to be used for calculating community weights for context ranking. Note that this is optional (if no entities are provided, we will not calculate community weights and only use the rank attribute in the community reports table for context ranking)\n",
"- Load all communities in the `create_final_communites` table from the ire-indexing engine, to be used to reconstruct the community graph hierarchy for dynamic community selection."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# parquet files generated from indexing pipeline\n",
"INPUT_DIR = \"./inputs/operation dulce\"\n",
"COMMUNITY_TABLE = \"create_final_communities\"\n",
"COMMUNITY_REPORT_TABLE = \"create_final_community_reports\"\n",
"ENTITY_TABLE = \"create_final_nodes\"\n",
"ENTITY_EMBEDDING_TABLE = \"create_final_entities\"\n",
"\n",
"# we don't fix a specific community level but instead use an agent to dynamicially\n",
"# search through all the community reports to check if they are relevant.\n",
"COMMUNITY_LEVEL = None"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total report count: 20\n",
"Report count after filtering by community level None: 20\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>community</th>\n",
" <th>full_content</th>\n",
" <th>level</th>\n",
" <th>rank</th>\n",
" <th>title</th>\n",
" <th>rank_explanation</th>\n",
" <th>summary</th>\n",
" <th>findings</th>\n",
" <th>full_content_json</th>\n",
" <th>id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>10</td>\n",
" <td># Paranormal Military Squad at Dulce Base: Dec...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Paranormal Military Squad at Dulce Base: Decod...</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The Paranormal Military Squad, stationed at Du...</td>\n",
" <td>[{'explanation': 'Jordan is a central figure i...</td>\n",
" <td>{\\n \"title\": \"Paranormal Military Squad at ...</td>\n",
" <td>1ba2d200-dd26-4693-affe-a5539d0a0e0d</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>11</td>\n",
" <td># Dulce and Paranormal Military Squad Operatio...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Dulce and Paranormal Military Squad Operations</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The community centers around Dulce, a secretiv...</td>\n",
" <td>[{'explanation': 'Dulce is described as a top-...</td>\n",
" <td>{\\n \"title\": \"Dulce and Paranormal Military...</td>\n",
" <td>a8a530b0-ae6b-44ea-b11c-9f70d138298d</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>12</td>\n",
" <td># Paranormal Military Squad and Dulce Base Ope...</td>\n",
" <td>1</td>\n",
" <td>7.5</td>\n",
" <td>Paranormal Military Squad and Dulce Base Opera...</td>\n",
" <td>The impact severity rating is relatively high ...</td>\n",
" <td>The community centers around the Paranormal Mi...</td>\n",
" <td>[{'explanation': 'Taylor is a central figure w...</td>\n",
" <td>{\\n \"title\": \"Paranormal Military Squad and...</td>\n",
" <td>0478975b-c805-4cc1-b746-82f3e689e2f3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>13</td>\n",
" <td># Mission Dynamics and Leadership: Cruz and Wa...</td>\n",
" <td>1</td>\n",
" <td>7.5</td>\n",
" <td>Mission Dynamics and Leadership: Cruz and Wash...</td>\n",
" <td>The impact severity rating is relatively high ...</td>\n",
" <td>This report explores the intricate dynamics of...</td>\n",
" <td>[{'explanation': 'Cruz is a central figure in ...</td>\n",
" <td>{\\n \"title\": \"Mission Dynamics and Leadersh...</td>\n",
" <td>b56f6e68-3951-4f07-8760-63700944a375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>14</td>\n",
" <td># Dulce Base and Paranormal Military Squad: Br...</td>\n",
" <td>1</td>\n",
" <td>8.5</td>\n",
" <td>Dulce Base and Paranormal Military Squad: Brid...</td>\n",
" <td>The impact severity rating is high due to the ...</td>\n",
" <td>The community centers around the Dulce Base, a...</td>\n",
" <td>[{'explanation': 'Sam Rivera, a member of the ...</td>\n",
" <td>{\\n \"title\": \"Dulce Base and Paranormal Mil...</td>\n",
" <td>736e7006-d050-4abb-a122-00febf3f540f</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" community full_content level rank \\\n",
"0 10 # Paranormal Military Squad at Dulce Base: Dec... 1 8.5 \n",
"1 11 # Dulce and Paranormal Military Squad Operatio... 1 8.5 \n",
"2 12 # Paranormal Military Squad and Dulce Base Ope... 1 7.5 \n",
"3 13 # Mission Dynamics and Leadership: Cruz and Wa... 1 7.5 \n",
"4 14 # Dulce Base and Paranormal Military Squad: Br... 1 8.5 \n",
"\n",
" title \\\n",
"0 Paranormal Military Squad at Dulce Base: Decod... \n",
"1 Dulce and Paranormal Military Squad Operations \n",
"2 Paranormal Military Squad and Dulce Base Opera... \n",
"3 Mission Dynamics and Leadership: Cruz and Wash... \n",
"4 Dulce Base and Paranormal Military Squad: Brid... \n",
"\n",
" rank_explanation \\\n",
"0 The impact severity rating is high due to the ... \n",
"1 The impact severity rating is high due to the ... \n",
"2 The impact severity rating is relatively high ... \n",
"3 The impact severity rating is relatively high ... \n",
"4 The impact severity rating is high due to the ... \n",
"\n",
" summary \\\n",
"0 The Paranormal Military Squad, stationed at Du... \n",
"1 The community centers around Dulce, a secretiv... \n",
"2 The community centers around the Paranormal Mi... \n",
"3 This report explores the intricate dynamics of... \n",
"4 The community centers around the Dulce Base, a... \n",
"\n",
" findings \\\n",
"0 [{'explanation': 'Jordan is a central figure i... \n",
"1 [{'explanation': 'Dulce is described as a top-... \n",
"2 [{'explanation': 'Taylor is a central figure w... \n",
"3 [{'explanation': 'Cruz is a central figure in ... \n",
"4 [{'explanation': 'Sam Rivera, a member of the ... \n",
"\n",
" full_content_json \\\n",
"0 {\\n \"title\": \"Paranormal Military Squad at ... \n",
"1 {\\n \"title\": \"Dulce and Paranormal Military... \n",
"2 {\\n \"title\": \"Paranormal Military Squad and... \n",
"3 {\\n \"title\": \"Mission Dynamics and Leadersh... \n",
"4 {\\n \"title\": \"Dulce Base and Paranormal Mil... \n",
"\n",
" id \n",
"0 1ba2d200-dd26-4693-affe-a5539d0a0e0d \n",
"1 a8a530b0-ae6b-44ea-b11c-9f70d138298d \n",
"2 0478975b-c805-4cc1-b746-82f3e689e2f3 \n",
"3 b56f6e68-3951-4f07-8760-63700944a375 \n",
"4 736e7006-d050-4abb-a122-00febf3f540f "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"community_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet\")\n",
"entity_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_TABLE}.parquet\")\n",
"report_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet\")\n",
"entity_embedding_df = pd.read_parquet(f\"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet\")\n",
"\n",
"communities = read_indexer_communities(community_df, entity_df, report_df)\n",
"reports = read_indexer_reports(\n",
" report_df,\n",
" entity_df,\n",
" community_level=COMMUNITY_LEVEL,\n",
" dynamic_community_selection=True,\n",
")\n",
"entities = read_indexer_entities(\n",
" entity_df, entity_embedding_df, community_level=COMMUNITY_LEVEL\n",
")\n",
"\n",
"print(f\"Total report count: {len(report_df)}\")\n",
"print(\n",
" f\"Report count after filtering by community level {COMMUNITY_LEVEL}: {len(reports)}\"\n",
")\n",
"\n",
"report_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Build global context with dynamic community selection\n",
"\n",
"The goal of dynamic community selection reduce the number of community reports that need to be processed in the map-reduce operation. To that end, we take advantage of the hierachical structure of the indexed dataset. We first ask the LLM to rate how relevant each level 0 community is with respect to the user query, we then traverse down the child node(s) if the current community report is deemed relevant.\n",
"\n",
"You can still set a `COMMUNITY_LEVEL` to filter out lower level community reports and apply dynamic community selection on the filtered reports.\n",
"\n",
"Note that the dataset is quite small, with only consist of 20 communities from 2 levels (level 0 and 1). Dynamic community selection is more effective when there are large amount of content to be filtered out."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"mini_llm = ChatOpenAI(\n",
" api_key=api_key,\n",
" model=\"gpt-4o-mini\",\n",
" api_type=OpenaiApiType.OpenAI, # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI\n",
" max_retries=20,\n",
")\n",
"mini_token_encoder = tiktoken.encoding_for_model(mini_llm.model)\n",
"\n",
"context_builder = GlobalCommunityContext(\n",
" community_reports=reports,\n",
" communities=communities,\n",
" entities=entities, # default to None if you don't want to use community weights for ranking\n",
" token_encoder=token_encoder,\n",
" dynamic_community_selection=True,\n",
" dynamic_community_selection_kwargs={\n",
" \"llm\": mini_llm,\n",
" \"token_encoder\": mini_token_encoder,\n",
" },\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Perform global search with dynamic community selection"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"context_builder_params = {\n",
" \"use_community_summary\": False, # False means using full community reports. True means using community short summaries.\n",
" \"shuffle_data\": True,\n",
" \"include_community_rank\": True,\n",
" \"min_community_rank\": 0,\n",
" \"community_rank_name\": \"rank\",\n",
" \"include_community_weight\": True,\n",
" \"community_weight_name\": \"occurrence weight\",\n",
" \"normalize_community_weight\": True,\n",
" \"max_tokens\": 12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" \"context_name\": \"Reports\",\n",
"}\n",
"\n",
"map_llm_params = {\n",
" \"max_tokens\": 1000,\n",
" \"temperature\": 0.0,\n",
" \"response_format\": {\"type\": \"json_object\"},\n",
"}\n",
"\n",
"reduce_llm_params = {\n",
" \"max_tokens\": 2000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000-1500)\n",
" \"temperature\": 0.0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"search_engine = GlobalSearch(\n",
" llm=llm,\n",
" context_builder=context_builder,\n",
" token_encoder=token_encoder,\n",
" max_data_tokens=12_000, # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)\n",
" map_llm_params=map_llm_params,\n",
" reduce_llm_params=reduce_llm_params,\n",
" allow_general_knowledge=False, # set this to True will add instruction to encourage the LLM to incorporate general knowledge in the response, which may increase hallucinations, but could be useful in some use cases.\n",
" json_mode=True, # set this to False if your LLM model does not support JSON mode.\n",
" context_builder_params=context_builder_params,\n",
" concurrent_coroutines=32,\n",
" response_type=\"multiple paragraphs\", # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"### Overview of Cosmic Vocalization\n",
"\n",
"Cosmic Vocalization is a phenomenon that has captured the attention of various individuals and groups, becoming a focal point for community interest. It is perceived as a significant cosmic event, with interpretations ranging from a strategic security concern to a metaphorical interstellar duet [Data: Reports (6)].\n",
"\n",
"### Key Stakeholders and Perspectives\n",
"\n",
"1. **Paranormal Military Squad**: This group is actively engaged with Cosmic Vocalization, treating it as a strategic element in their security measures. Their involvement underscores the importance of Cosmic Vocalization in broader security contexts. They metaphorically view the Universe as a concert hall, suggesting a unique perspective on cosmic events and their implications for human entities [Data: Reports (6)].\n",
"\n",
"2. **Alex Mercer**: Alex Mercer perceives Cosmic Vocalization as part of an interstellar duet, indicating a responsive and perhaps artistic approach to understanding these cosmic phenomena. This perspective highlights the diverse interpretations and cultural significance attributed to Cosmic Vocalization [Data: Reports (6)].\n",
"\n",
"3. **Taylor Cruz**: Taylor Cruz expresses concerns about Cosmic Vocalization, fearing it might serve as a homing tune. This perspective introduces a layer of urgency and potential threat, suggesting that Cosmic Vocalization could have implications beyond mere observation, possibly affecting security or existential considerations [Data: Reports (6)].\n",
"\n",
"### Implications\n",
"\n",
"The involvement of these stakeholders and their varied perspectives on Cosmic Vocalization illustrate the complexity and multifaceted nature of this phenomenon. It is not only a subject of scientific and strategic interest but also a cultural and existential topic that prompts diverse interpretations and responses. The strategic engagement by the Paranormal Military Squad and the concerns raised by individuals like Taylor Cruz highlight the potential significance of Cosmic Vocalization in both security and broader cosmic contexts [Data: Reports (6)].\n"
]
}
],
"source": [
"result = await search_engine.asearch(\n",
" \"What is Cosmic Vocalization and who are involved in it?\"\n",
")\n",
"\n",
"print(result.response)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>occurrence weight</th>\n",
" <th>content</th>\n",
" <th>rank</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>15</td>\n",
" <td>Dulce Base and the Paranormal Military Squad: ...</td>\n",
" <td>1.00</td>\n",
" <td># Dulce Base and the Paranormal Military Squad...</td>\n",
" <td>9.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>Earth's Interstellar Communication Initiative</td>\n",
" <td>0.16</td>\n",
" <td># Earth's Interstellar Communication Initiativ...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>16</td>\n",
" <td>Dulce Military Base and Alien Intelligence Com...</td>\n",
" <td>0.08</td>\n",
" <td># Dulce Military Base and Alien Intelligence C...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>18</td>\n",
" <td>Paranormal Military Squad Team and Dulce Base'...</td>\n",
" <td>0.04</td>\n",
" <td># Paranormal Military Squad Team and Dulce Bas...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>19</td>\n",
" <td>Central Terminal and Viewing Monitors at Dulce...</td>\n",
" <td>0.02</td>\n",
" <td># Central Terminal and Viewing Monitors at Dul...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>4</td>\n",
" <td>Dulce Facility and Control Room of Dulce: Extr...</td>\n",
" <td>0.02</td>\n",
" <td># Dulce Facility and Control Room of Dulce: Ex...</td>\n",
" <td>8.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>6</td>\n",
" <td>Cosmic Vocalization and Universe Interactions</td>\n",
" <td>0.02</td>\n",
" <td># Cosmic Vocalization and Universe Interaction...</td>\n",
" <td>7.5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id title occurrence weight \\\n",
"0 15 Dulce Base and the Paranormal Military Squad: ... 1.00 \n",
"1 1 Earth's Interstellar Communication Initiative 0.16 \n",
"2 16 Dulce Military Base and Alien Intelligence Com... 0.08 \n",
"3 18 Paranormal Military Squad Team and Dulce Base'... 0.04 \n",
"4 19 Central Terminal and Viewing Monitors at Dulce... 0.02 \n",
"5 4 Dulce Facility and Control Room of Dulce: Extr... 0.02 \n",
"6 6 Cosmic Vocalization and Universe Interactions 0.02 \n",
"\n",
" content rank \n",
"0 # Dulce Base and the Paranormal Military Squad... 9.5 \n",
"1 # Earth's Interstellar Communication Initiativ... 8.5 \n",
"2 # Dulce Military Base and Alien Intelligence C... 8.5 \n",
"3 # Paranormal Military Squad Team and Dulce Bas... 8.5 \n",
"4 # Central Terminal and Viewing Monitors at Dul... 8.5 \n",
"5 # Dulce Facility and Control Room of Dulce: Ex... 8.5 \n",
"6 # Cosmic Vocalization and Universe Interaction... 7.5 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# inspect the data used to build the context for the LLM responses\n",
"result.context_data[\"reports\"]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Build context (gpt-4o-mini)\n",
"LLM calls: 12. Prompt tokens: 8565. Output tokens: 1091.\n",
"Map-reduce (gpt-4o)\n",
"LLM calls: 2. Prompt tokens: 5771. Output tokens: 600.\n"
]
}
],
"source": [
"# inspect number of LLM calls and tokens in dynamic community selection\n",
"llm_calls = result.llm_calls_categories[\"build_context\"]\n",
"prompt_tokens = result.prompt_tokens_categories[\"build_context\"]\n",
"output_tokens = result.output_tokens_categories[\"build_context\"]\n",
"print(\n",
" f\"Build context ({mini_llm.model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")\n",
"# inspect number of LLM calls and tokens in map-reduce\n",
"llm_calls = result.llm_calls_categories[\"map\"] + result.llm_calls_categories[\"reduce\"]\n",
"prompt_tokens = (\n",
" result.prompt_tokens_categories[\"map\"] + result.prompt_tokens_categories[\"reduce\"]\n",
")\n",
"output_tokens = (\n",
" result.output_tokens_categories[\"map\"] + result.output_tokens_categories[\"reduce\"]\n",
")\n",
"print(\n",
" f\"Map-reduce ({llm.model})\\nLLM calls: {llm_calls}. Prompt tokens: {prompt_tokens}. Output tokens: {output_tokens}.\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -32,6 +32,7 @@ from graphrag.query.factories import (
get_local_search_engine,
)
from graphrag.query.indexer_adapters import (
read_indexer_communities,
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
@ -52,8 +53,10 @@ async def global_search(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
) -> tuple[
@ -67,8 +70,10 @@ async def global_search(
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from create_final_communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- community_level (int): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
@ -80,13 +85,21 @@ async def global_search(
------
TODO: Document any exceptions to expect.
"""
reports = read_indexer_reports(community_reports, nodes, community_level)
_entities = read_indexer_entities(nodes, entities, community_level)
_communities = read_indexer_communities(communities, nodes, community_reports)
reports = read_indexer_reports(
community_reports,
nodes,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
communities=_communities,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
@ -99,8 +112,10 @@ async def global_search_streaming(
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
communities: pd.DataFrame,
community_reports: pd.DataFrame,
community_level: int,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
query: str,
) -> AsyncGenerator:
@ -113,8 +128,10 @@ async def global_search_streaming(
- config (GraphRagConfig): A graphrag configuration (from settings.yaml)
- nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
- entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
- communities (pd.DataFrame): A DataFrame containing the final communities (from create_final_communities.parquet)
- community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
- community_level (int): The community level to search at.
- dynamic_community_selection (bool): Enable dynamic community selection instead of using all community reports at a fixed level. Note that you can still provide community_level to cap the maximum level to search.
- response_type (str): The type of response to return.
- query (str): The user query to search for.
@ -126,13 +143,21 @@ async def global_search_streaming(
------
TODO: Document any exceptions to expect.
"""
reports = read_indexer_reports(community_reports, nodes, community_level)
_entities = read_indexer_entities(nodes, entities, community_level)
_communities = read_indexer_communities(communities, nodes, community_reports)
reports = read_indexer_reports(
community_reports,
nodes,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
search_engine = get_global_search_engine(
config,
reports=reports,
entities=_entities,
communities=_communities,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
)
search_result = search_engine.astream_search(query=query)
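
A minimal usage sketch of the updated `global_search` API with the new `communities` frame and `dynamic_community_selection` flag, matching the signature above. The import locations, the `create_graphrag_config` loader, and the `(response, context_data)` return shape are assumptions, not confirmed by this diff:

```python
import asyncio

import pandas as pd

from graphrag.api import global_search  # assumed export location
from graphrag.config import create_graphrag_config  # assumed config loader

async def main() -> None:
    # hypothetical project root containing settings.yaml and pipeline outputs
    config = create_graphrag_config(root_dir=".")
    out = "./output"  # hypothetical output directory of the indexing pipeline
    response, context_data = await global_search(  # return shape is an assumption
        config=config,
        nodes=pd.read_parquet(f"{out}/create_final_nodes.parquet"),
        entities=pd.read_parquet(f"{out}/create_final_entities.parquet"),
        communities=pd.read_parquet(f"{out}/create_final_communities.parquet"),
        community_reports=pd.read_parquet(
            f"{out}/create_final_community_reports.parquet"
        ),
        community_level=None,  # None lets dynamic selection traverse every level
        dynamic_community_selection=True,
        response_type="multiple paragraphs",
        query="What is Cosmic Vocalization and who are involved in it?",
    )
    print(response)

asyncio.run(main())
```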

View File

@ -327,6 +327,10 @@ def _query_cli(
help="The community level in the Leiden community hierarchy from which to load community reports. Higher values represent reports from smaller communities."
),
] = 2,
dynamic_community_selection: Annotated[
bool,
typer.Option(help="Use global search with dynamic community selection."),
] = False,
response_type: Annotated[
str,
typer.Option(
@ -355,6 +359,7 @@ def _query_cli(
data_dir=data,
root_dir=root,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
streaming=streaming,
query=query,

View File

@ -22,7 +22,8 @@ def run_global_search(
config_filepath: Path | None,
data_dir: Path | None,
root_dir: Path,
community_level: int,
community_level: int | None,
dynamic_community_selection: bool,
response_type: str,
streaming: bool,
query: str,
@ -42,12 +43,14 @@ def run_global_search(
parquet_list=[
"create_final_nodes.parquet",
"create_final_entities.parquet",
"create_final_communities.parquet",
"create_final_community_reports.parquet",
],
optional_list=[],
)
final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
final_communities: pd.DataFrame = dataframe_dict["create_final_communities"]
final_community_reports: pd.DataFrame = dataframe_dict[
"create_final_community_reports"
]
@ -63,8 +66,10 @@ def run_global_search(
config=config,
nodes=final_nodes,
entities=final_entities,
communities=final_communities,
community_reports=final_community_reports,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
):
@ -85,8 +90,10 @@ def run_global_search(
config=config,
nodes=final_nodes,
entities=final_entities,
communities=final_communities,
community_reports=final_community_reports,
community_level=community_level,
dynamic_community_selection=dynamic_community_selection,
response_type=response_type,
query=query,
)

View File

@ -119,8 +119,16 @@ GLOBAL_SEARCH_MAP_MAX_TOKENS = 1000
GLOBAL_SEARCH_REDUCE_MAX_TOKENS = 2_000
GLOBAL_SEARCH_CONCURRENCY = 32
# DRIFT Search
# Global Search with dynamic community selection
DYNAMIC_SEARCH_LLM_MODEL = "gpt-4o-mini"
DYNAMIC_SEARCH_RATE_THRESHOLD = 1
DYNAMIC_SEARCH_KEEP_PARENT = False
DYNAMIC_SEARCH_NUM_REPEATS = 1
DYNAMIC_SEARCH_USE_SUMMARY = False
DYNAMIC_SEARCH_CONCURRENT_COROUTINES = 16
DYNAMIC_SEARCH_MAX_LEVEL = 2
# DRIFT Search
DRIFT_SEARCH_LLM_TEMPERATURE = 0
DRIFT_SEARCH_LLM_TOP_P = 1
DRIFT_SEARCH_LLM_N = 3

View File

@ -43,3 +43,33 @@ class GlobalSearchConfig(BaseModel):
description="The number of concurrent requests.",
default=defs.GLOBAL_SEARCH_CONCURRENCY,
)
# configurations for dynamic community selection
dynamic_search_llm: str = Field(
description="LLM model to use for dynamic community selection",
default=defs.DYNAMIC_SEARCH_LLM_MODEL,
)
dynamic_search_threshold: int = Field(
description="Rating threshold in include a community report",
default=defs.DYNAMIC_SEARCH_RATE_THRESHOLD,
)
dynamic_search_keep_parent: bool = Field(
description="Keep parent community if any of the child communities are relevant",
default=defs.DYNAMIC_SEARCH_KEEP_PARENT,
)
dynamic_search_num_repeats: int = Field(
description="Number of times to rate the same community report",
default=defs.DYNAMIC_SEARCH_NUM_REPEATS,
)
dynamic_search_use_summary: bool = Field(
description="Use community summary instead of full_context",
default=defs.DYNAMIC_SEARCH_USE_SUMMARY,
)
dynamic_search_concurrent_coroutines: int = Field(
description="Number of concurrent coroutines to rate community reports",
default=defs.DYNAMIC_SEARCH_CONCURRENT_COROUTINES,
)
dynamic_search_max_level: int = Field(
description="The maximum level of community hierarchy to consider if none of the processed communities are relevant",
default=defs.DYNAMIC_SEARCH_MAX_LEVEL,
)
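
The new fields are plain pydantic settings, so they can also be set programmatically. A small sketch, assuming the model lives where GraphRAG keeps its config models and that the pre-existing fields all have defaults:

```python
# Hypothetical import path for the config model shown above.
from graphrag.config.models.global_search_config import GlobalSearchConfig

config = GlobalSearchConfig(
    dynamic_search_llm="gpt-4o-mini",         # model used to rate community reports
    dynamic_search_threshold=1,               # keep reports rated >= 1
    dynamic_search_keep_parent=False,         # drop a parent once a child is relevant
    dynamic_search_num_repeats=1,             # rating passes per report
    dynamic_search_use_summary=False,         # rate full_content, not summaries
    dynamic_search_concurrent_coroutines=16,  # parallel rating requests
    dynamic_search_max_level=2,               # deepest fallback level
)
print(config.dynamic_search_llm)
```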

View File

@ -29,7 +29,7 @@ CLAIM_DESCRIPTION = "description"
CLAIM_DETAILS = "claim_details"
# COMMUNITY HIERARCHY TABLE SCHEMA
SUB_COMMUNITY = "sub_communitty"
SUB_COMMUNITY = "sub_community"
SUB_COMMUNITY_SIZE = "sub_community_size"
COMMUNITY_LEVEL = "level"

View File

@ -25,6 +25,9 @@ class Community(Named):
covariate_ids: dict[str, list[str]] | None = None
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
sub_community_ids: list[str] | None = None
"""List of community IDs of the child nodes of this community (optional)."""
attributes: dict[str, Any] | None = None
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
@ -45,6 +48,7 @@ class Community(Named):
entities_key: str = "entity_ids",
relationships_key: str = "relationship_ids",
covariates_key: str = "covariate_ids",
sub_communities_key: str = "sub_community_ids",
attributes_key: str = "attributes",
size_key: str = "size",
period_key: str = "period",
@ -58,6 +62,7 @@ class Community(Named):
entity_ids=d.get(entities_key),
relationship_ids=d.get(relationships_key),
covariate_ids=d.get(covariates_key),
sub_community_ids=d.get(sub_communities_key),
attributes=d.get(attributes_key),
size=d.get(size_key),
period=d.get(period_key),
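
A short sketch of hydrating a `Community` with the new `sub_community_ids` field through `from_dict`; the minimal set of keys used here is an assumption (only the keys shown in this diff are confirmed):

```python
from graphrag.model import Community  # import used elsewhere in this commit

# Level 0 parent pointing at two hypothetical level 1 children.
community = Community.from_dict({
    "id": "0",
    "title": "Level 0 community",
    "level": "0",
    "sub_community_ids": ["10", "11"],
})
print(community.sub_community_ids)  # ['10', '11']
```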

View File

@ -4,6 +4,7 @@
"""Base classes for global and local context builders."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
import pandas as pd
@ -12,13 +13,27 @@ from graphrag.query.context_builder.conversation_history import (
)
@dataclass
class ContextBuilderResult:
"""A class to hold the results of the build_context."""
context_chunks: str | list[str]
context_records: dict[str, pd.DataFrame]
llm_calls: int = 0
prompt_tokens: int = 0
output_tokens: int = 0
class GlobalContextBuilder(ABC):
"""Base class for global-search context builders."""
@abstractmethod
def build_context(
self, conversation_history: ConversationHistory | None = None, **kwargs
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
async def build_context(
self,
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> ContextBuilderResult:
"""Build the context for the global search mode."""
@ -31,7 +46,7 @@ class LocalContextBuilder(ABC):
query: str,
conversation_history: ConversationHistory | None = None,
**kwargs,
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
) -> ContextBuilderResult:
"""Build the context for the local search mode."""
@ -43,5 +58,5 @@ class DRIFTContextBuilder(ABC):
self,
query: str,
**kwargs,
) -> pd.DataFrame:
) -> tuple[pd.DataFrame, dict[str, int]]:
"""Build the context for the primer search actions."""

View File

@ -0,0 +1,183 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Algorithm to dynamically select relevant communities with respect to a query."""
import asyncio
import logging
from collections import Counter
from copy import deepcopy
from time import time
from typing import Any
import tiktoken
from graphrag.model import Community, CommunityReport
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.context_builder.rate_relevancy import rate_relevancy
from graphrag.query.llm.base import BaseLLM
log = logging.getLogger(__name__)
DEFAULT_RATE_LLM_PARAMS = {"temperature": 0.0, "max_tokens": 2000}
class DynamicCommunitySelection:
"""Dynamic community selection to select community reports that are relevant to the query.
Any community report with a rating EQUAL or ABOVE the rating_threshold is considered relevant.
"""
def __init__(
self,
community_reports: list[CommunityReport],
communities: list[Community],
llm: BaseLLM,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
use_summary: bool = False,
threshold: int = 1,
keep_parent: bool = False,
num_repeats: int = 1,
max_level: int = 2,
concurrent_coroutines: int = 8,
llm_kwargs: Any = DEFAULT_RATE_LLM_PARAMS,
):
self.llm = llm
self.token_encoder = token_encoder
self.rate_query = rate_query
self.num_repeats = num_repeats
self.use_summary = use_summary
self.threshold = threshold
self.keep_parent = keep_parent
self.max_level = max_level
self.semaphore = asyncio.Semaphore(concurrent_coroutines)
self.llm_kwargs = llm_kwargs
self.reports = {report.community_id: report for report in community_reports}
# mapping from community to sub communities
self.node2children = {
community.id: (
[]
if community.sub_community_ids is None
else community.sub_community_ids
)
for community in communities
}
# mapping from community to parent community
self.node2parent: dict[str, str] = {
sub_community: community
for community, sub_communities in self.node2children.items()
for sub_community in sub_communities
}
# mapping from level to communities
self.levels: dict[str, list[str]] = {}
for community in communities:
if community.level not in self.levels:
self.levels[community.level] = []
if community.id in self.reports:
self.levels[community.level].append(community.id)
# start from root communities (level 0)
self.starting_communities = self.levels["0"]
async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any]]:
"""
Select relevant communities with respect to the query.
Args:
query: the query to rate against
"""
start = time()
queue = deepcopy(self.starting_communities)
level = 0
ratings = {} # store the ratings for each community
llm_info: dict[str, Any] = {
"llm_calls": 0,
"prompt_tokens": 0,
"output_tokens": 0,
}
relevant_communities = set()
while queue:
gather_results = await asyncio.gather(*[
rate_relevancy(
query=query,
description=(
self.reports[community].summary
if self.use_summary
else self.reports[community].full_content
),
llm=self.llm,
token_encoder=self.token_encoder,
rate_query=self.rate_query,
num_repeats=self.num_repeats,
semaphore=self.semaphore,
**self.llm_kwargs,
)
for community in queue
])
communities_to_rate = []
for community, result in zip(queue, gather_results, strict=True):
rating = result["rating"]
log.debug(
"dynamic community selection: community %s rating %s",
community,
rating,
)
ratings[community] = rating
llm_info["llm_calls"] += result["llm_calls"]
llm_info["prompt_tokens"] += result["prompt_tokens"]
llm_info["output_tokens"] += result["output_tokens"]
if rating >= self.threshold:
relevant_communities.add(community)
# find children nodes of the current node and append them to the queue
# TODO check why some sub_communities are NOT in report_df
if community in self.node2children:
for sub_community in self.node2children[community]:
if sub_community in self.reports:
communities_to_rate.append(sub_community)
else:
log.debug(
"dynamic community selection: cannot find community %s in reports",
sub_community,
)
# remove parent node if the current node is deemed relevant
if not self.keep_parent and community in self.node2parent:
relevant_communities.discard(self.node2parent[community])
queue = communities_to_rate
level += 1
if (
(len(queue) == 0)
and (len(relevant_communities) == 0)
and (str(level) in self.levels)
and (level <= self.max_level)
):
log.info(
"dynamic community selection: no relevant community "
"reports, adding all reports at level %s to rate.",
level,
)
# append all communities at the next level to queue
queue = self.levels[str(level)]
community_reports = [
self.reports[community] for community in relevant_communities
]
end = time()
log.info(
"dynamic community selection (took: %ss)\n"
"\trating distribution %s\n"
"\t%s out of %s community reports are relevant\n"
"\tprompt tokens: %s, output tokens: %s",
int(end - start),
dict(sorted(Counter(ratings.values()).items())),
len(relevant_communities),
len(self.reports),
llm_info["prompt_tokens"],
llm_info["output_tokens"],
)
llm_info["ratings"] = ratings
return community_reports, llm_info
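
A minimal usage sketch of running the selector on its own, assuming `reports`, `communities`, and the `llm` client were loaded elsewhere, and that the model name is "gpt-4o":

import asyncio
import tiktoken

selector = DynamicCommunitySelection(
    community_reports=reports,
    communities=communities,
    llm=llm,
    token_encoder=tiktoken.encoding_for_model("gpt-4o"),
    threshold=1,   # keep any report rated >= 1
    max_level=2,   # descend to deeper levels only while nothing is relevant yet
)
relevant_reports, info = asyncio.run(selector.select("What are the main themes?"))
print(len(relevant_reports), info["llm_calls"], info["prompt_tokens"], info["output_tokens"])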

View File

@ -0,0 +1,23 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Global search with dynamic community selection prompt."""
RATE_QUERY = """
---Role---
You are a helpful assistant responsible for deciding whether the provided information is useful in answering a given question, even if it is only partially relevant.
---Goal---
On a scale from 0 to 5, please rate how relevant or helpful the provided information is in answering the question.
---Information---
{description}
---Question---
{question}
---Target response length and format---
Please respond in the following JSON format with two entries:
- "reason": the reasoning behind your rating; include the information you considered.
- "rating": the relevancy rating from 0 to 5, where 0 is the least relevant and 5 is the most relevant.
{{
    "reason": str,
    "rating": int
}}
"""

View File

@ -0,0 +1,76 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Algorithm to rate the relevancy between a query and description text."""
import asyncio
import logging
from contextlib import nullcontext
from typing import Any
import numpy as np
import tiktoken
from graphrag.llm.openai.utils import try_parse_json_object
from graphrag.query.context_builder.rate_prompt import RATE_QUERY
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
async def rate_relevancy(
query: str,
description: str,
llm: BaseLLM,
token_encoder: tiktoken.Encoding,
rate_query: str = RATE_QUERY,
num_repeats: int = 1,
semaphore: asyncio.Semaphore | None = None,
**llm_kwargs: Any,
) -> dict[str, Any]:
"""
Rate the relevancy between the query and description on a scale of 0 to 5.
Args:
query: the query (or question) to rate against
description: the community description to rate, it can be the community
title, summary, or the full content.
llm: LLM model to use for rating
token_encoder: token encoder
num_repeats: number of times to repeat the rating process for the same community (default: 1)
llm_kwargs: additional arguments to pass to the LLM model
semaphore: asyncio.Semaphore to limit the number of concurrent LLM calls (default: None)
"""
llm_calls, prompt_tokens, output_tokens, ratings = 0, 0, 0, []
messages = [
{
"role": "system",
"content": rate_query.format(description=description, question=query),
},
{"role": "user", "content": query},
]
for _ in range(num_repeats):
async with semaphore if semaphore is not None else nullcontext():
response = await llm.agenerate(messages=messages, **llm_kwargs)
try:
_, parsed_response = try_parse_json_object(response)
ratings.append(parsed_response["rating"])
except KeyError:
# in case of json parsing error, default to rating 1 so the report is kept.
# json parsing error should rarely happen.
log.info("Error parsing json response, defaulting to rating 1")
ratings.append(1)
llm_calls += 1
prompt_tokens += num_tokens(messages[0]["content"], token_encoder)
output_tokens += num_tokens(response, token_encoder)
# select the decision with the most votes
options, counts = np.unique(ratings, return_counts=True)
rating = int(options[np.argmax(counts)])
return {
"rating": rating,
"ratings": ratings,
"llm_calls": llm_calls,
"prompt_tokens": prompt_tokens,
"output_tokens": output_tokens,
}
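
Note the majority vote above: np.unique returns the sorted unique ratings, and np.argmax picks the first maximum, so a tie resolves to the lower rating. A small self-contained illustration:

import numpy as np

options, counts = np.unique([3, 5, 5, 2, 5], return_counts=True)
print(int(options[np.argmax(counts)]))  # 5 -- clear majority

options, counts = np.unique([2, 5, 2, 5], return_counts=True)
print(int(options[np.argmax(counts)]))  # 2 -- tie resolves to the lower rating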

View File

@ -3,14 +3,13 @@
"""Query Factory methods to support CLI."""
import tiktoken
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from copy import deepcopy
from graphrag.config import (
GraphRagConfig,
LLMType,
)
import tiktoken
from graphrag.config import GraphRagConfig
from graphrag.model import (
Community,
CommunityReport,
Covariate,
Entity,
@ -18,9 +17,7 @@ from graphrag.model import (
TextUnit,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.llm.get_client import get_llm, get_text_embedder
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
@ -36,71 +33,6 @@ from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores import BaseVectorStore
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
is_azure_client = (
config.llm.type == LLMType.AzureOpenAIChat
or config.llm.type == LLMType.AzureOpenAI
)
debug_llm_key = config.llm.api_key or ""
llm_debug_info = {
**config.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
config.llm.audience
if config.llm.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.llm.api_key
else None
),
api_base=config.llm.api_base,
organization=config.llm.organization,
model=config.llm.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=config.llm.deployment_name,
api_version=config.llm.api_version,
max_retries=config.llm.max_retries,
request_timeout=config.llm.request_timeout,
)
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = config.embeddings.llm.api_key or ""
llm_debug_info = {
**config.embeddings.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if config.embeddings.llm.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = config.embeddings.llm.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.embeddings.llm.api_key
else None
),
api_base=config.embeddings.llm.api_base,
organization=config.llm.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,
deployment_name=config.embeddings.llm.deployment_name,
api_version=config.embeddings.llm.api_version,
max_retries=config.embeddings.llm.max_retries,
)
def get_local_search_engine(
config: GraphRagConfig,
reports: list[CommunityReport],
@ -160,16 +92,39 @@ def get_global_search_engine(
config: GraphRagConfig,
reports: list[CommunityReport],
entities: list[Entity],
communities: list[Community],
response_type: str,
dynamic_community_selection: bool = False,
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
gs_config = config.global_search
dynamic_community_selection_kwargs = {}
if dynamic_community_selection:
gs_config = config.global_search
_config = deepcopy(config)
_config.llm.model = _config.llm.deployment_name = gs_config.dynamic_search_llm
dynamic_community_selection_kwargs.update({
"llm": get_llm(_config),
"token_encoder": tiktoken.encoding_for_model(gs_config.dynamic_search_llm),
"keep_parent": gs_config.dynamic_search_keep_parent,
"num_repeats": gs_config.dynamic_search_num_repeats,
"use_summary": gs_config.dynamic_search_use_summary,
"concurrent_coroutines": gs_config.dynamic_search_concurrent_coroutines,
"threshold": gs_config.dynamic_search_threshold,
"max_level": gs_config.dynamic_search_max_level,
})
return GlobalSearch(
llm=get_llm(config),
context_builder=GlobalCommunityContext(
community_reports=reports, entities=entities, token_encoder=token_encoder
community_reports=reports,
communities=communities,
entities=entities,
token_encoder=token_encoder,
dynamic_community_selection=dynamic_community_selection,
dynamic_community_selection_kwargs=dynamic_community_selection_kwargs,
),
token_encoder=token_encoder,
max_data_tokens=gs_config.data_max_tokens,

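With the factory above, turning the feature on from already-loaded data looks roughly like this (a sketch; `config`, `reports`, `entities`, and `communities` are assumed to come from the indexer adapters, and the call runs in an async context):

from graphrag.query.factories import get_global_search_engine

search_engine = get_global_search_engine(
    config=config,
    reports=reports,
    entities=entities,
    communities=communities,
    response_type="multiple paragraphs",
    dynamic_community_selection=True,  # rate reports per query instead of fixing a level
)
result = await search_engine.asearch("What are the main themes in this dataset?")
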
View File

@ -3,17 +3,27 @@
"""Indexing-Engine to Query Read Adapters.
The parts of these functions that do type adaptation, renaming, collating, etc. should eventually go away.
Ideally this is just a straight read-thorugh into the object model.
Ideally this is just a straight read-through into the object model.
"""
import logging
from typing import cast
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.model import CommunityReport, Covariate, Entity, Relationship, TextUnit
from graphrag.index.operations.summarize_communities import restore_community_hierarchy
from graphrag.model import (
Community,
CommunityReport,
Covariate,
Entity,
Relationship,
TextUnit,
)
from graphrag.query.factories import get_text_embedder
from graphrag.query.input.loaders.dfs import (
read_communities,
read_community_reports,
read_covariates,
read_entities,
@ -23,6 +33,8 @@ from graphrag.query.input.loaders.dfs import (
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.vector_stores.base import BaseVectorStore
log = logging.getLogger(__name__)
def read_indexer_text_units(final_text_units: pd.DataFrame) -> list[TextUnit]:
"""Read in the Text Units from the raw indexing outputs."""
@ -66,23 +78,32 @@ def read_indexer_relationships(final_relationships: pd.DataFrame) -> list[Relati
def read_indexer_reports(
final_community_reports: pd.DataFrame,
final_nodes: pd.DataFrame,
community_level: int,
community_level: int | None,
dynamic_community_selection: bool = False,
content_embedding_col: str = "full_content_embedding",
config: GraphRagConfig | None = None,
) -> list[CommunityReport]:
"""Read in the Community Reports from the raw indexing outputs."""
"""Read in the Community Reports from the raw indexing outputs.
If not dynamic_community_selection, then select reports with the max community level that an entity belongs to.
"""
report_df = final_community_reports
entity_df = final_nodes
entity_df = _filter_under_community_level(entity_df, community_level)
entity_df.loc[:, "community"] = entity_df["community"].fillna(-1)
entity_df.loc[:, "community"] = entity_df["community"].astype(int)
if community_level is not None:
entity_df = _filter_under_community_level(entity_df, community_level)
report_df = _filter_under_community_level(report_df, community_level)
entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
entity_df["community"] = entity_df["community"].astype(str)
filtered_community_df = entity_df["community"].drop_duplicates()
if not dynamic_community_selection:
# perform community level roll up
entity_df.loc[:, "community"] = entity_df["community"].fillna(-1)
entity_df.loc[:, "community"] = entity_df["community"].astype(int)
entity_df = entity_df.groupby(["title"]).agg({"community": "max"}).reset_index()
entity_df["community"] = entity_df["community"].astype(str)
filtered_community_df = entity_df["community"].drop_duplicates()
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
report_df = _filter_under_community_level(report_df, community_level)
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
if config and (
content_embedding_col not in report_df.columns
or report_df.loc[:, content_embedding_col].isna().any()
@ -113,13 +134,15 @@ def read_indexer_report_embeddings(
def read_indexer_entities(
final_nodes: pd.DataFrame,
final_entities: pd.DataFrame,
community_level: int,
community_level: int | None,
) -> list[Entity]:
"""Read in the Entities from the raw indexing outputs."""
entity_df = final_nodes
entity_embedding_df = final_entities
entity_df = _filter_under_community_level(entity_df, community_level)
if community_level is not None:
entity_df = _filter_under_community_level(entity_df, community_level)
entity_df = cast(pd.DataFrame, entity_df[["title", "degree", "community"]]).rename(
columns={"title": "name", "degree": "rank"}
)
@ -128,11 +151,11 @@ def read_indexer_entities(
entity_df["community"] = entity_df["community"].astype(int)
entity_df["rank"] = entity_df["rank"].astype(int)
# for duplicate entities, keep the one with the highest community level
# group entities by name and rank and remove duplicated community IDs
entity_df = (
entity_df.groupby(["name", "rank"]).agg({"community": "max"}).reset_index()
entity_df.groupby(["name", "rank"]).agg({"community": set}).reset_index()
)
entity_df["community"] = entity_df["community"].apply(lambda x: [str(x)])
entity_df["community"] = entity_df["community"].apply(lambda x: [str(i) for i in x])
entity_df = entity_df.merge(
entity_embedding_df, on="name", how="inner"
).drop_duplicates(subset=["name"])
@ -155,6 +178,60 @@ def read_indexer_entities(
)
def read_indexer_communities(
final_communities: pd.DataFrame,
final_nodes: pd.DataFrame,
final_community_reports: pd.DataFrame,
) -> list[Community]:
"""Read in the Communities from the raw indexing outputs.
Reconstruct the community hierarchy information and add it to the sub-community field.
"""
community_df = final_communities
node_df = final_nodes
report_df = final_community_reports
# ensure the communities match the community reports
missing_reports = community_df[
~community_df.id.isin(report_df.community.unique())
].id.to_list()
if len(missing_reports):
log.warning("Missing reports for communities: %s", missing_reports)
community_df = community_df.loc[
community_df.id.isin(report_df.community.unique())
]
node_df = node_df.loc[node_df.community.isin(report_df.community.unique())]
# reconstruct the community hierarchy
# note that restore_community_hierarchy only returns communities that have sub-communities
community_hierarchy = restore_community_hierarchy(input=node_df)
community_hierarchy = (
community_hierarchy.groupby(["community"])
.agg({"sub_community": list})
.reset_index()
.rename(columns={"community": "id", "sub_community": "sub_community_ids"})
)
# add sub community IDs to community DataFrame
community_df = community_df.merge(community_hierarchy, on="id", how="left")
# replace NaN sub community IDs with empty list
community_df.sub_community_ids = community_df.sub_community_ids.apply(
lambda x: x if isinstance(x, list) else []
)
return read_communities(
community_df,
id_col="id",
short_id_col="id",
title_col="title",
level_col="level",
entities_col=None,
relationships_col=None,
covariates_col=None,
sub_communities_col="sub_community_ids",
attributes_cols=None,
)
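
A sketch of wiring the new adapter together with the existing ones (the output directory and parquet file names are assumed from the standard indexing pipeline):

import pandas as pd

INPUT_DIR = "./output"  # assumed location of the indexing outputs
community_df = pd.read_parquet(f"{INPUT_DIR}/create_final_communities.parquet")
entity_df = pd.read_parquet(f"{INPUT_DIR}/create_final_nodes.parquet")
report_df = pd.read_parquet(f"{INPUT_DIR}/create_final_community_reports.parquet")
entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/create_final_entities.parquet")

communities = read_indexer_communities(community_df, entity_df, report_df)
# with dynamic selection, keep reports at every level (no community-level filter)
reports = read_indexer_reports(
    report_df, entity_df, community_level=None, dynamic_community_selection=True
)
entities = read_indexer_entities(entity_df, entity_embedding_df, community_level=None)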
def embed_community_reports(
reports_df: pd.DataFrame,
embedder: OpenAIEmbedding,

View File

@ -215,6 +215,7 @@ def read_communities(
entities_col: str | None = "entity_ids",
relationships_col: str | None = "relationship_ids",
covariates_col: str | None = "covariate_ids",
sub_communities_col: str | None = "sub_community_ids",
attributes_cols: list[str] | None = None,
) -> list[Community]:
"""Read communities from a dataframe."""
@ -230,6 +231,7 @@ def read_communities(
covariate_ids=to_optional_dict(
row, covariates_col, key_type=str, value_type=str
),
sub_community_ids=to_optional_list(row, sub_communities_col, item_type=str),
attributes=(
{col: row.get(col) for col in attributes_cols}
if attributes_cols

View File

@ -0,0 +1,76 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Initialize LLM and Embedding clients."""
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from graphrag.config import GraphRagConfig, LLMType
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
def get_llm(config: GraphRagConfig) -> ChatOpenAI:
"""Get the LLM client."""
is_azure_client = (
config.llm.type == LLMType.AzureOpenAIChat
or config.llm.type == LLMType.AzureOpenAI
)
debug_llm_key = config.llm.api_key or ""
llm_debug_info = {
**config.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_llm_key)}",
}
audience = (
config.llm.audience
if config.llm.audience
else "https://cognitiveservices.azure.com/.default"
)
print(f"creating llm client with {llm_debug_info}") # noqa T201
return ChatOpenAI(
api_key=config.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.llm.api_key
else None
),
api_base=config.llm.api_base,
organization=config.llm.organization,
model=config.llm.model,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=config.llm.deployment_name,
api_version=config.llm.api_version,
max_retries=config.llm.max_retries,
request_timeout=config.llm.request_timeout,
)
def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
"""Get the LLM client for embeddings."""
is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding
debug_embedding_api_key = config.embeddings.llm.api_key or ""
llm_debug_info = {
**config.embeddings.llm.model_dump(),
"api_key": f"REDACTED,len={len(debug_embedding_api_key)}",
}
if config.embeddings.llm.audience is None:
audience = "https://cognitiveservices.azure.com/.default"
else:
audience = config.embeddings.llm.audience
print(f"creating embedding llm client with {llm_debug_info}") # noqa T201
return OpenAIEmbedding(
api_key=config.embeddings.llm.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not config.embeddings.llm.api_key
else None
),
api_base=config.embeddings.llm.api_base,
organization=config.llm.organization,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=config.embeddings.llm.model,
deployment_name=config.embeddings.llm.deployment_name,
api_version=config.embeddings.llm.api_version,
max_retries=config.embeddings.llm.max_retries,
)

View File

@ -6,12 +6,16 @@
from enum import Enum
from typing import Any, cast
import httpx
import openai
OPENAI_RETRY_ERROR_TYPES = (
# TODO: update these when we update to OpenAI 1+ library
cast(Any, openai).RateLimitError,
cast(Any, openai).APIConnectionError,
cast(Any, openai).APIError,
cast(Any, httpx).RemoteProtocolError,
cast(Any, httpx).ReadTimeout,
# TODO: replace with comparable OpenAI 1+ error
)

View File

@ -31,8 +31,14 @@ class SearchResult:
# actual text strings that are in the context window, built from context_data
context_text: str | list[str] | dict[str, str]
completion_time: float
# total LLM calls and token usage
llm_calls: int
prompt_tokens: int
output_tokens: int
# breakdown of LLM calls and token usage
llm_calls_categories: dict[str, int] | None = None
prompt_tokens_categories: dict[str, int] | None = None
output_tokens_categories: dict[str, int] | None = None
T = TypeVar("T", GlobalContextBuilder, LocalContextBuilder, DRIFTContextBuilder)
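
The new *_categories fields expose where calls and tokens went; for global search the keys are "build_context", "map", and "reduce" (a sketch, with `search_engine` assumed to be constructed via the factory):

result = await search_engine.asearch("What are the main themes?")
print(result.llm_calls, result.prompt_tokens, result.output_tokens)  # totals
print(result.prompt_tokens_categories)  # e.g. {"build_context": ..., "map": ..., "reduce": ...}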

View File

@ -7,8 +7,6 @@ import json
import logging
from typing import Any
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
@ -39,7 +37,11 @@ class DriftAction:
self.follow_ups: list[DriftAction] = (
follow_ups if follow_ups is not None else []
)
self.metadata: dict[str, Any] = {}
self.metadata: dict[str, Any] = {
"llm_calls": 0,
"prompt_tokens": 0,
"output_tokens": 0,
}
@property
def is_complete(self) -> bool:
@ -85,12 +87,10 @@ class DriftAction:
if self.answer is None:
log.warning("No answer found for query: %s", self.query)
generated_tokens = 0
else:
generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
self.metadata.update({
"token_ct": search_result.prompt_tokens + generated_tokens
})
self.metadata["llm_calls"] += 1
self.metadata["prompt_tokens"] += search_result.prompt_tokens
self.metadata["output_tokens"] += search_result.output_tokens
self.follow_ups = response.pop("follow_up_queries", [])
if not self.follow_ups:

View File

@ -69,7 +69,6 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.covariates = covariates
self.embedding_vectorstore_key = embedding_vectorstore_key
self.llm_tokens = 0
self.local_mixed_context = (
local_mixed_context or self.init_local_context_builder()
)
@ -160,7 +159,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
and isinstance(query_embedding[0], type(embedding[0]))
)
def build_context(self, query: str, **kwargs) -> pd.DataFrame:
def build_context(
self, query: str, **kwargs
) -> tuple[pd.DataFrame, dict[str, int]]:
"""
Build DRIFT search context.
@ -172,6 +173,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
Returns
-------
pd.DataFrame: Top-k most similar documents.
dict[str, int]: Number of LLM calls, prompt tokens, and output tokens.
Raises
------
@ -192,7 +194,6 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
)
query_embedding, token_ct = query_processor(query)
self.llm_tokens += token_ct
report_df = self.convert_reports_to_df(self.reports)
@ -219,4 +220,4 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
# Sort by similarity and select top-k
top_k = report_df.nlargest(self.config.drift_k_followups, "similarity")
return top_k.loc[:, ["short_id", "community_id", "full_content"]]
return top_k.loc[:, ["short_id", "community_id", "full_content"]], token_ct

View File

@ -50,7 +50,7 @@ class PrimerQueryProcessor:
self.token_encoder = token_encoder
self.reports = reports
def expand_query(self, query: str) -> tuple[str, int]:
def expand_query(self, query: str) -> tuple[str, dict[str, int]]:
"""
Expand the query using a random community report template.
@ -59,9 +59,8 @@ class PrimerQueryProcessor:
Returns
-------
tuple[str, int]: Expanded query text and the number of tokens used.
tuple[str, dict[str, int]]: Expanded query text and a dict with the LLM call, prompt token, and output token counts.
"""
token_ct = 0
template = secrets.choice(self.reports).full_content # nosec S311
prompt = f"""Create a hypothetical answer to the following query: {query}\n\n
@ -72,13 +71,19 @@ class PrimerQueryProcessor:
messages = [{"role": "user", "content": prompt}]
text = self.chat_llm.generate(messages)
token_ct = num_tokens(text + query)
prompt_tokens = num_tokens(prompt, self.token_encoder)
output_tokens = num_tokens(text, self.token_encoder)
token_ct = {
"llm_calls": 1,
"prompt_tokens": prompt_tokens,
"output_tokens": output_tokens,
}
if text == "":
log.warning("Failed to generate expansion for query: %s", query)
return query, token_ct
return text, token_ct
def __call__(self, query: str) -> tuple[list[float], int]:
def __call__(self, query: str) -> tuple[list[float], dict[str, int]]:
"""
Call method to process the query, expand it, and embed the result.
@ -117,7 +122,7 @@ class DRIFTPrimer:
async def decompose_query(
self, query: str, reports: pd.DataFrame
) -> tuple[dict, int]:
) -> tuple[dict, dict[str, int]]:
"""
Decompose the query into subqueries based on the fetched global structures.
@ -127,7 +132,7 @@ class DRIFTPrimer:
Returns
-------
tuple[dict, int]: Parsed response and the number of tokens used.
tuple[dict, dict[str, int]]: Parsed response and a dict with the LLM call, prompt token, and output token counts.
"""
community_reports = "\n\n".join(reports["full_content"].tolist())
prompt = DRIFT_PRIMER_PROMPT.format(
@ -140,7 +145,12 @@ class DRIFTPrimer:
)
parsed_response = json.loads(response)
token_ct = num_tokens(prompt + response, self.token_encoder)
token_ct = {
"llm_calls": 1,
"prompt_tokens": num_tokens(prompt, self.token_encoder),
"output_tokens": num_tokens(response, self.token_encoder),
}
return parsed_response, token_ct
@ -163,7 +173,7 @@ class DRIFTPrimer:
start_time = time.perf_counter()
report_folds = self.split_reports(top_k_reports)
tasks = [self.decompose_query(query, fold) for fold in report_folds]
results_with_tokens = await tqdm_asyncio.gather(*tasks)
results_with_tokens = await tqdm_asyncio.gather(*tasks, leave=False)
completion_time = time.perf_counter() - start_time
@ -173,7 +183,8 @@ class DRIFTPrimer:
context_text=top_k_reports.to_json() or "",
completion_time=completion_time,
llm_calls=len(results_with_tokens),
prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
prompt_tokens=sum(ct["prompt_tokens"] for _, ct in results_with_tokens),
output_tokens=sum(ct["output_tokens"] for _, ct in results_with_tokens),
)
def split_reports(self, reports: pd.DataFrame) -> list[pd.DataFrame]:

View File

@ -163,7 +163,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
action.asearch(search_engine=search_engine, global_query=global_query)
for action in actions
]
return await tqdm_asyncio.gather(*tasks)
return await tqdm_asyncio.gather(*tasks, leave=False)
async def asearch(
self,
@ -190,20 +190,25 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "DRIFT Search query cannot be empty."
raise ValueError(error_msg)
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
start_time = time.perf_counter()
primer_token_ct = 0
context_token_ct = 0
# Check if query state is empty
if not self.query_state.graph:
# Prime the search with the primer
primer_context = self.context_builder.build_context(query)
context_token_ct = self.context_builder.llm_tokens
primer_context, token_ct = self.context_builder.build_context(query)
llm_calls["build_context"] = token_ct["llm_calls"]
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
output_tokens["build_context"] = token_ct["prompt_tokens"]
primer_response = await self.primer.asearch(
query=query, top_k_reports=primer_context
)
primer_token_ct = primer_response.prompt_tokens
llm_calls["primer"] = primer_response.llm_calls
prompt_tokens["primer"] = primer_response.prompt_tokens
output_tokens["primer"] = primer_response.output_tokens
# Package response into DriftAction
init_action = self._process_primer_results(query, primer_response)
self.query_state.add_action(init_action)
@ -233,9 +238,10 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
t_elapsed = time.perf_counter() - start_time
# Calculate token usage
total_tokens = (
primer_token_ct + context_token_ct + self.query_state.action_token_ct()
)
token_ct = self.query_state.action_token_ct()
llm_calls["action"] = token_ct["llm_calls"]
prompt_tokens["action"] = token_ct["prompt_tokens"]
output_tokens["action"] = token_ct["output_tokens"]
# Package up context data
response_state, context_data, context_text = self.query_state.serialize(
@ -247,10 +253,12 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
context_data=context_data,
context_text=context_text,
completion_time=t_elapsed,
llm_calls=1
+ self.config.primer_folds
+ (self.config.drift_k_followups - llm_call_offset) * self.config.n_depth,
prompt_tokens=total_tokens,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
def search(

View File

@ -136,6 +136,15 @@ class QueryState:
if source_action and target_action:
self.relate_actions(source_action, target_action, weight)
def action_token_ct(self) -> int:
def action_token_ct(self) -> dict[str, int]:
"""Return the token count of the action."""
return sum(action.metadata.get("token_ct", 0) for action in self.graph.nodes)
llm_calls, prompt_tokens, output_tokens = 0, 0, 0
for action in self.graph.nodes:
llm_calls += action.metadata["llm_calls"]
prompt_tokens += action.metadata["prompt_tokens"]
output_tokens += action.metadata["output_tokens"]
return {
"llm_calls": llm_calls,
"prompt_tokens": prompt_tokens,
"output_tokens": output_tokens,
}

View File

@ -5,16 +5,19 @@
from typing import Any
import pandas as pd
import tiktoken
from graphrag.model import CommunityReport, Entity
from graphrag.model import Community, CommunityReport, Entity
from graphrag.query.context_builder.builders import ContextBuilderResult
from graphrag.query.context_builder.community_context import (
build_community_context,
)
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.context_builder.dynamic_community_selection import (
DynamicCommunitySelection,
)
from graphrag.query.structured_search.base import GlobalContextBuilder
@ -24,17 +27,32 @@ class GlobalCommunityContext(GlobalContextBuilder):
def __init__(
self,
community_reports: list[CommunityReport],
communities: list[Community],
entities: list[Entity] | None = None,
token_encoder: tiktoken.Encoding | None = None,
dynamic_community_selection: bool = False,
dynamic_community_selection_kwargs: dict[str, Any] | None = None,
random_state: int = 86,
):
self.community_reports = community_reports
self.entities = entities
self.token_encoder = token_encoder
self.dynamic_community_selection = None
if dynamic_community_selection and isinstance(
dynamic_community_selection_kwargs, dict
):
self.dynamic_community_selection = DynamicCommunitySelection(
community_reports=community_reports,
communities=communities,
llm=dynamic_community_selection_kwargs.pop("llm"),
token_encoder=dynamic_community_selection_kwargs.pop("token_encoder"),
**dynamic_community_selection_kwargs,
)
self.random_state = random_state
def build_context(
async def build_context(
self,
query: str,
conversation_history: ConversationHistory | None = None,
use_community_summary: bool = True,
column_delimiter: str = "|",
@ -50,10 +68,11 @@ class GlobalCommunityContext(GlobalContextBuilder):
conversation_history_user_turns_only: bool = True,
conversation_history_max_turns: int | None = 5,
**kwargs: Any,
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
) -> ContextBuilderResult:
"""Prepare batches of community report data table as context data for global search."""
conversation_history_context = ""
final_context_data = {}
llm_calls, prompt_tokens, output_tokens = 0, 0, 0
if conversation_history:
# build conversation history context
(
@ -69,8 +88,18 @@ class GlobalCommunityContext(GlobalContextBuilder):
if conversation_history_context != "":
final_context_data = conversation_history_context_data
community_reports = self.community_reports
if self.dynamic_community_selection is not None:
(
community_reports,
dynamic_info,
) = await self.dynamic_community_selection.select(query)
llm_calls += dynamic_info["llm_calls"]
prompt_tokens += dynamic_info["prompt_tokens"]
output_tokens += dynamic_info["output_tokens"]
community_context, community_context_data = build_community_context(
community_reports=self.community_reports,
community_reports=community_reports,
entities=self.entities,
token_encoder=self.token_encoder,
use_community_summary=use_community_summary,
@ -104,4 +133,10 @@ class GlobalCommunityContext(GlobalContextBuilder):
# Update the final context data with the provided community_context_data
final_context_data.update(community_context_data)
return final_context, final_context_data
return ContextBuilderResult(
context_chunks=final_context,
context_records=final_context_data,
llm_calls=llm_calls,
prompt_tokens=prompt_tokens,
output_tokens=output_tokens,
)
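
The context builder can also be exercised directly; a sketch, assuming `context_builder` is a GlobalCommunityContext constructed with dynamic selection enabled (as in the factory) and that this runs in an async context:

result = await context_builder.build_context(
    query="What are the main themes?",
    use_community_summary=False,
)
print(result.llm_calls, result.prompt_tokens, result.output_tokens)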

View File

@ -45,7 +45,7 @@ DEFAULT_REDUCE_LLM_PARAMS = {
log = logging.getLogger(__name__)
@dataclass
@dataclass(kw_only=True)
class GlobalSearchResult(SearchResult):
"""A GlobalSearch result."""
@ -105,24 +105,25 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
conversation_history: ConversationHistory | None = None,
) -> AsyncGenerator:
"""Stream the global search response."""
context_chunks, context_records = self.context_builder.build_context(
conversation_history=conversation_history, **self.context_builder_params
context_result = await self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**self.context_builder_params,
)
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_start(context_chunks) # type: ignore
callback.on_map_response_start(context_result.context_chunks) # type: ignore
map_responses = await asyncio.gather(*[
self._map_response_single_batch(
context_data=data, query=query, **self.map_llm_params
)
for data in context_chunks
for data in context_result.context_chunks
])
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_end(map_responses) # type: ignore
# send context records first before sending the reduce response
yield context_records
yield context_result.context_records
async for response in self._stream_reduce_response(
map_responses=map_responses, # type: ignore
query=query,
@ -145,26 +146,33 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
- Step 2: Combine the answers from step 2 to generate the final answer
"""
# Step 1: Generate answers for each batch of community short summaries
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
start_time = time.time()
context_chunks, context_records = self.context_builder.build_context(
conversation_history=conversation_history, **self.context_builder_params
context_result = await self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_start(context_chunks) # type: ignore
callback.on_map_response_start(context_result.context_chunks) # type: ignore
map_responses = await asyncio.gather(*[
self._map_response_single_batch(
context_data=data, query=query, **self.map_llm_params
)
for data in context_chunks
for data in context_result.context_chunks
])
if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_end(map_responses)
map_llm_calls = sum(response.llm_calls for response in map_responses)
map_prompt_tokens = sum(response.prompt_tokens for response in map_responses)
llm_calls["map"] = sum(response.llm_calls for response in map_responses)
prompt_tokens["map"] = sum(response.prompt_tokens for response in map_responses)
output_tokens["map"] = sum(response.output_tokens for response in map_responses)
# Step 2: Combine the intermediate answers from step 2 to generate the final answer
reduce_response = await self._reduce_response(
@ -172,17 +180,24 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
query=query,
**self.reduce_llm_params,
)
llm_calls["reduce"] = reduce_response.llm_calls
prompt_tokens["reduce"] = reduce_response.prompt_tokens
output_tokens["reduce"] = reduce_response.output_tokens
return GlobalSearchResult(
response=reduce_response.response,
context_data=context_records,
context_text=context_chunks,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
map_responses=map_responses,
reduce_context_data=reduce_response.context_data,
reduce_context_text=reduce_response.context_text,
completion_time=time.time() - start_time,
llm_calls=map_llm_calls + reduce_response.llm_calls,
prompt_tokens=map_prompt_tokens + reduce_response.prompt_tokens,
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
def search(
@ -230,6 +245,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=num_tokens(search_response, self.token_encoder),
)
except Exception:
@ -241,6 +257,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
@ -319,6 +336,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
completion_time=time.time() - start_time,
llm_calls=0,
prompt_tokens=0,
output_tokens=0,
)
filtered_key_points = sorted(
@ -372,6 +390,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=num_tokens(search_response, self.token_encoder),
)
except Exception:
log.exception("Exception in reduce_response")
@ -382,6 +401,7 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
async def _stream_reduce_response(

View File

@ -16,6 +16,7 @@ from graphrag.model import (
Relationship,
TextUnit,
)
from graphrag.query.context_builder.builders import ContextBuilderResult
from graphrag.query.context_builder.community_context import (
build_community_context,
)
@ -113,7 +114,7 @@ class LocalSearchMixedContext(LocalContextBuilder):
community_context_name: str = "Reports",
column_delimiter: str = "|",
**kwargs: dict[str, Any],
) -> tuple[str | list[str], dict[str, pd.DataFrame]]:
) -> ContextBuilderResult:
"""
Build data context for local search prompt.
@ -217,7 +218,10 @@ class LocalSearchMixedContext(LocalContextBuilder):
final_context.append(text_unit_context)
final_context_data = {**final_context_data, **text_unit_context_data}
return ("\n\n".join(final_context), final_context_data)
return ContextBuilderResult(
context_chunks="\n\n".join(final_context),
context_records=final_context_data,
)
def _build_community_context(
self,

View File

@ -63,25 +63,30 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
"""Build local search context that fits a single context window and generate answer for the user query."""
start_time = time.time()
search_prompt = ""
context_text, context_records = self.context_builder.build_context(
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
try:
if "drift_query" in kwargs:
drift_query = kwargs["drift_query"]
search_prompt = self.system_prompt.format(
context_data=context_text,
context_data=context_result.context_chunks,
response_type=self.response_type,
global_query=drift_query,
)
else:
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
@ -94,25 +99,33 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_records,
context_text=context_text,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _asearch")
return SearchResult(
response="",
context_data=context_records,
context_text=context_text,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)
async def astream_search(
@ -123,14 +136,14 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
"""Build local search context that fits a single context window and generate answer for the user query."""
start_time = time.time()
context_text, context_records = self.context_builder.build_context(
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**self.context_builder_params,
)
log.info("GENERATE ANSWER: %s. QUERY: %s", start_time, query)
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type
context_data=context_result.context_chunks, response_type=self.response_type
)
search_messages = [
{"role": "system", "content": search_prompt},
@ -138,7 +151,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
]
# send context records first before sending the reduce response
yield context_records
yield context_result.context_records
async for response in self.llm.astream_generate( # type: ignore
messages=search_messages,
callbacks=self.callbacks,
@ -155,16 +168,22 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
"""Build local search context that fits a single context window and generate answer for the user question."""
start_time = time.time()
search_prompt = ""
context_text, context_records = self.context_builder.build_context(
llm_calls, prompt_tokens, output_tokens = {}, {}, {}
context_result = self.context_builder.build_context(
query=query,
conversation_history=conversation_history,
**kwargs,
**self.context_builder_params,
)
llm_calls["build_context"] = context_result.llm_calls
prompt_tokens["build_context"] = context_result.prompt_tokens
output_tokens["build_context"] = context_result.output_tokens
log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query)
try:
search_prompt = self.system_prompt.format(
context_data=context_text, response_type=self.response_type
context_data=context_result.context_chunks,
response_type=self.response_type,
)
search_messages = [
{"role": "system", "content": search_prompt},
@ -177,23 +196,31 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
callbacks=self.callbacks,
**self.llm_params,
)
llm_calls["response"] = 1
prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder)
output_tokens["response"] = num_tokens(response, self.token_encoder)
return SearchResult(
response=response,
context_data=context_records,
context_text=context_text,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
llm_calls=sum(llm_calls.values()),
prompt_tokens=sum(prompt_tokens.values()),
output_tokens=sum(output_tokens.values()),
llm_calls_categories=llm_calls,
prompt_tokens_categories=prompt_tokens,
output_tokens_categories=output_tokens,
)
except Exception:
log.exception("Exception in _map_response_single_batch")
return SearchResult(
response="",
context_data=context_records,
context_text=context_text,
context_data=context_result.context_records,
context_text=context_result.context_chunks,
completion_time=time.time() - start_time,
llm_calls=1,
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
output_tokens=0,
)