In [None]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.

# Multi Index Search
This notebook demonstrates multi-index search using the GraphRAG API.

Indexes created from Wikipedia state articles for Alaska, California, DC, Maryland, NY and Washington are used.

In [None]:
import asyncio

import pandas as pd

from graphrag.api.query import (
    multi_index_basic_search,
    multi_index_drift_search,
    multi_index_global_search,
    multi_index_local_search,
)
from graphrag.config.create_graphrag_config import create_graphrag_config

# Names of the indexes to query, kept in sorted order.
indexes = sorted(["alaska", "california", "dc", "maryland", "ny", "washington"])

print(indexes)


def _lancedb_config(index):
    """Return the LanceDB vector-store settings for a single index name."""
    return {
        "type": "lancedb",
        "db_uri": f"inputs/{index}/lancedb",
        "container_name": "default",
        "overwrite": True,
        "index_name": f"{index}",
    }


# One vector-store configuration per index, keyed by index name.
vector_store_configs = {index: _lancedb_config(index) for index in indexes}

In [None]:
# Settings for the chat model entry. "<API_BASE_URL>" is a placeholder value.
chat_model_settings = {
    "model_supports_json": True,
    "parallelization_num_threads": 50,
    "parallelization_stagger": 0.3,
    "async_mode": "threaded",
    "type": "azure_openai_chat",
    "model": "gpt-4o",
    "auth_type": "azure_managed_identity",
    "api_base": "<API_BASE_URL>",
    "api_version": "2024-02-15-preview",
    "deployment_name": "gpt-4o",
}

# Settings for the embedding model entry; same placeholder endpoint as above.
embedding_model_settings = {
    "parallelization_num_threads": 50,
    "parallelization_stagger": 0.3,
    "async_mode": "threaded",
    "type": "azure_openai_embedding",
    "model": "text-embedding-3-large",
    "auth_type": "azure_managed_identity",
    "api_base": "<API_BASE_URL>",
    "api_version": "2024-02-15-preview",
    "deployment_name": "text-embedding-3-large",
}

# Full configuration: models, the per-index vector stores built earlier, and
# the prompt file used by each search method.
config_data = {
    "models": {
        "default_chat_model": chat_model_settings,
        "default_embedding_model": embedding_model_settings,
    },
    "vector_store": vector_store_configs,
    "local_search": {
        "prompt": "prompts/local_search_system_prompt.txt",
        "llm_max_tokens": 12000,
    },
    "global_search": {
        "map_prompt": "prompts/global_search_map_system_prompt.txt",
        "reduce_prompt": "prompts/global_search_reduce_system_prompt.txt",
        "knowledge_prompt": "prompts/global_search_knowledge_system_prompt.txt",
    },
    "drift_search": {
        "prompt": "prompts/drift_search_system_prompt.txt",
        "reduce_prompt": "prompts/drift_search_reduce_prompt.txt",
    },
    "basic_search": {"prompt": "prompts/basic_search_system_prompt.txt"},
}

# Build the GraphRAG config from the dict; "." is passed as the second argument.
parameters = create_graphrag_config(config_data, ".")

# Event loop on which the search cells schedule their tasks.
loop = asyncio.get_event_loop()

### Multi-index Global Search

In [None]:
entities = [pd.read_parquet(f"inputs/{index}/entities.parquet") for index in indexes]
communities = [
    pd.read_parquet(f"inputs/{index}/communities.parquet") for index in indexes
]
community_reports = [
    pd.read_parquet(f"inputs/{index}/community_reports.parquet") for index in indexes
]

task = loop.create_task(
    multi_index_global_search(
        parameters,
        entities,
        communities,
        community_reports,
        indexes,
        1,
        False,
        "Multiple Paragraphs",
        False,
        "Describe this dataset.",
    )
)
results = await task

#### Print report

In [None]:
# The first element of `results` is the report to display.
global_report = results[0]
print(global_report)

#### Show context links back to original index

In [None]:
# For each listed report id, print the context record returned by the search
# next to the title of the matching row in the index that record points to.
# NOTE(review): the search inputs are read from `community_reports.parquet`,
# while this lookup reads `create_final_community_reports.parquet` - confirm
# that file exists for every index.
report_context = results[1]["reports"]
for report_id in [120, 129, 40, 16, 204, 143, 85, 122, 83]:
    # Look the record up once; raises IndexError when the id is not present.
    match = [row for row in report_context if row["id"] == str(report_id)][0]  # noqa: RUF015
    index_name = match["index_name"]
    index_id = match["index_id"]
    print(report_id, index_name, index_id)
    index_reports = pd.read_parquet(
        f"inputs/{index_name}/create_final_community_reports.parquet"
    )
    print(match["title"])
    # Select the row(s) of the source index whose community equals index_id.
    same_community = index_reports["community"] == int(index_id)
    print(index_reports[same_community]["title"].to_numpy()[0])

### Multi-index Local Search

In [None]:
entities = [pd.read_parquet(f"inputs/{index}/entities.parquet") for index in indexes]
communities = [
    pd.read_parquet(f"inputs/{index}/communities.parquet") for index in indexes
]
community_reports = [
    pd.read_parquet(f"inputs/{index}/community_reports.parquet") for index in indexes
]
covariates = [
    pd.read_parquet(f"inputs/{index}/covariates.parquet") for index in indexes
]
text_units = [
    pd.read_parquet(f"inputs/{index}/text_units.parquet") for index in indexes
]
relationships = [
    pd.read_parquet(f"inputs/{index}/relationships.parquet") for index in indexes
]

task = loop.create_task(
    multi_index_local_search(
        parameters,
        entities,
        communities,
        community_reports,
        text_units,
        relationships,
        covariates,
        indexes,
        1,
        "Multiple Paragraphs",
        False,
        "weather",
    )
)
results = await task

#### Print report

In [None]:
# The first element of `results` is the report to display.
local_report = results[0]
print(local_report)

#### Show context links back to original index

In [None]:
# For each listed report id, print the context record returned by the search
# next to the title of the matching row in the index that record points to.
# NOTE(review): the search inputs are read from `community_reports.parquet`,
# while this lookup reads `create_final_community_reports.parquet` - confirm
# that file exists for every index.
report_context = results[1]["reports"]
for report_id in [47, 213]:
    # Look the record up once; raises IndexError when the id is not present.
    match = [row for row in report_context if row["id"] == str(report_id)][0]  # noqa: RUF015
    index_name = match["index_name"]
    index_id = match["index_id"]
    print(report_id, index_name, index_id)
    index_reports = pd.read_parquet(
        f"inputs/{index_name}/create_final_community_reports.parquet"
    )
    print(match["title"])
    # Select the row(s) of the source index whose community equals index_id.
    same_community = index_reports["community"] == int(index_id)
    print(index_reports[same_community]["title"].to_numpy()[0])
# For each listed entity id, print the first 100 characters of the description
# in the search context and of the matching row in the originating index.
# NOTE(review): the search inputs are read from `entities.parquet`, while this
# lookup reads `create_final_entities.parquet` - confirm that file exists.
entity_context = results[1]["entities"]
for entity_id in [500, 502, 506, 1960, 1961, 1962]:
    # Look the record up once; raises IndexError when the id is not present.
    match = [row for row in entity_context if row["id"] == str(entity_id)][0]  # noqa: RUF015
    index_name = match["index_name"]
    index_id = match["index_id"]
    print(entity_id, index_name, index_id)
    index_entities = pd.read_parquet(
        f"inputs/{index_name}/create_final_entities.parquet"
    )
    print(match["description"][:100])
    # Select the row(s) whose human_readable_id equals index_id.
    same_entity = index_entities["human_readable_id"] == int(index_id)
    print(index_entities[same_entity]["description"].to_numpy()[0][:100])
# For each listed relationship id, print the description in the search context
# and the description of the matching row in the originating index.
# NOTE(review): the search inputs are read from `relationships.parquet`, while
# this lookup reads `create_final_relationships.parquet` - confirm that file
# exists.
relationship_context = results[1]["relationships"]
for relationship_id in [1805, 1806]:
    # Look the record up once; raises IndexError when the id is not present.
    match = [  # noqa: RUF015
        row for row in relationship_context if row["id"] == str(relationship_id)
    ][0]
    index_name = match["index_name"]
    index_id = match["index_id"]
    print(relationship_id, index_name, index_id)
    index_relationships = pd.read_parquet(
        f"inputs/{index_name}/create_final_relationships.parquet"
    )
    print(match["description"])
    # Select the row(s) whose human_readable_id equals index_id.
    same_relationship = index_relationships["human_readable_id"] == int(index_id)
    print(index_relationships[same_relationship]["description"].to_numpy()[0])
# For each listed claim id, print the description in the search context and
# the description of the matching row in the originating index.
# NOTE(review): the search inputs are read from `covariates.parquet`, while
# this lookup reads `create_final_covariates.parquet` - confirm that file
# exists.
for claim_id in [100]:
    # Look the record up once instead of re-scanning the list for every field;
    # raises IndexError when the id is not present in the context.
    claim = [row for row in results[1]["claims"] if row["id"] == str(claim_id)][0]  # noqa: RUF015
    index_name = claim["index_name"]
    index_id = claim["index_id"]
    # Print the claim id being resolved. The previous code printed
    # `relationship_id`, a stale variable left over from the preceding loop,
    # which mislabeled the output and made this loop depend on that one.
    print(claim_id, index_name, index_id)
    index_claims = pd.read_parquet(
        f"inputs/{index_name}/create_final_covariates.parquet"
    )
    print(claim["description"])
    # Select the row(s) whose human_readable_id equals index_id.
    same_claim = index_claims["human_readable_id"] == int(index_id)
    print(index_claims[same_claim]["description"].to_numpy()[0])

### Multi-index Drift Search

In [None]:
entities = [pd.read_parquet(f"inputs/{index}/entities.parquet") for index in indexes]
communities = [
    pd.read_parquet(f"inputs/{index}/communities.parquet") for index in indexes
]
community_reports = [
    pd.read_parquet(f"inputs/{index}/community_reports.parquet") for index in indexes
]
text_units = [
    pd.read_parquet(f"inputs/{index}/text_units.parquet") for index in indexes
]
relationships = [
    pd.read_parquet(f"inputs/{index}/relationships.parquet") for index in indexes
]

task = loop.create_task(
    multi_index_drift_search(
        parameters,
        entities,
        communities,
        community_reports,
        text_units,
        relationships,
        indexes,
        1,
        "Multiple Paragraphs",
        False,
        "agriculture",
    )
)
results = await task

#### Print report

In [None]:
# The first element of `results` is the report to display.
drift_report = results[0]
print(drift_report)

#### Show context links back to original index

In [None]:
# The drift context is keyed by question. For each listed report id, find the
# first question whose reports contain it, then print that context record next
# to the title of the matching row in the originating index.
# NOTE(review): the search inputs are read from `community_reports.parquet`,
# while this lookup reads `create_final_community_reports.parquet` - confirm
# that file exists for every index.
for report_id in [47, 236]:
    for question in results[1]:
        resq = results[1][question]
        # An empty list covers both "no reports at all" and "id not cited here".
        matches = [row for row in resq["reports"] if row["id"] == str(report_id)]
        if not matches:
            continue
        match = matches[0]
        index_name = match["index_name"]
        index_id = match["index_id"]
        print(question, report_id, index_name, index_id)
        index_reports = pd.read_parquet(
            f"inputs/{index_name}/create_final_community_reports.parquet"
        )
        print(match["title"])
        # Select the row(s) of the source index whose community equals index_id.
        same_community = index_reports["community"] == int(index_id)
        print(index_reports[same_community]["title"].to_numpy()[0])
        # Only the first question that cites this report is shown.
        break
# For each listed source id, find the first question whose sources contain it,
# then print the first 250 characters of the text in the search context and of
# the text at label `index_id` in the originating index.
# NOTE(review): the search inputs are read from `text_units.parquet`, while
# this lookup reads `create_final_text_units.parquet` - confirm that file
# exists for every index.
for source_id in [10, 16, 19, 20, 21, 22, 24, 29, 93, 95]:
    for question in results[1]:
        resq = results[1][question]
        # An empty list covers both "no sources at all" and "id not cited here".
        matches = [row for row in resq["sources"] if row["id"] == str(source_id)]
        if not matches:
            continue
        match = matches[0]
        index_name = match["index_name"]
        index_id = match["index_id"]
        print(question, source_id, index_name, index_id)
        index_sources = pd.read_parquet(
            f"inputs/{index_name}/create_final_text_units.parquet"
        )
        print(match["text"][:250])
        # Unlike the other lookups, this one selects by DataFrame index label.
        source_row = index_sources.loc[int(index_id)]
        print(source_row["text"][:250])
        # Only the first question that cites this source is shown.
        break

### Multi-index Basic Search

In [None]:
text_units = [
    pd.read_parquet(f"inputs/{index}/text_units.parquet") for index in indexes
]

task = loop.create_task(
    multi_index_basic_search(
        parameters, text_units, indexes, False, "industry in maryland"
    )
)
results = await task

#### Print report

In [None]:
# The first element of `results` is the report to display.
basic_report = results[0]
print(basic_report)

#### Show context links back to original text

Note that the original index name is not saved in the context data for basic search.

In [None]:
# Preview the first 250 characters of the first two sources in the context.
for source_id in range(2):
    source = results[1]["sources"][source_id]
    print(source["text"][:250])