Move prompts (#1404)

* Move indexing prompts to root

* Move query prompts to root

* Export query prompts during init

* Extract general knowledge prompt

* Load query prompts from disk

* Semver

* Fix unit tests
Nathan Evans 2024-11-14 10:45:37 -08:00 committed by GitHub
parent c8c354e357
commit 51912b2e03
43 changed files with 269 additions and 93 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Centralized prompts and export all for easier injection."
}

View File

@ -8,7 +8,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain
## Entity/Relationship Extraction
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/graph/prompts.py)
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/entity_extraction.py)
### Tokens (values provided by extractor)
@ -20,7 +20,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain
## Summarize Entity/Relationship Descriptions
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py)
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/summarize_descriptions.py)
### Tokens (values provided by extractor)
@ -29,7 +29,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain
## Claim Extraction
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/claims/prompts.py)
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/claim_extraction.py)
### Tokens (values provided by extractor)
@ -47,7 +47,7 @@ See the [configuration documentation](../config/overview.md) for details on how
## Generate Community Reports
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/community_reports/prompts.py)
[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/community_report.py)
### Tokens (values provided by extractor)
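
To illustrate the override mechanism described at the top of this file, here is a minimal sketch of replacing one of these prompts with a plain-text file; the `./ragtest` root is an assumption, and the file name follows the init export added later in this commit:

```python
# Hedged sketch: override a default indexing prompt with a plain-text file.
# Project root and file name are assumptions based on this commit's init export.
from pathlib import Path

root = Path("./ragtest")  # hypothetical project root
prompts_dir = root / "prompts"
prompts_dir.mkdir(parents=True, exist_ok=True)

custom_prompt = "-Goal-\nExtract all organizations and people from the text...\n"
(prompts_dir / "entity_extraction.txt").write_text(custom_prompt, encoding="utf-8")
```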

View File

@ -56,11 +56,11 @@ Below are the key parameters of the [GlobalSearch class](https://github.com/micr
* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/community_context.py) object to be used for preparing context data from community reports
* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/map_system_prompt.py)
* `reduce_system_prompt`: prompt template used in the `reduce` stage. Default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_map_system_prompt.py)
* `reduce_system_prompt`: prompt template used in the `reduce` stage. Default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_reduce_system_prompt.py)
* `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
* `allow_general_knowledge`: setting this to True will append additional instructions to the `reduce_system_prompt` that prompt the LLM to incorporate relevant real-world knowledge from outside the dataset. Note that this may increase hallucinations, but it can be useful in certain scenarios. Default is False
* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_knowledge_system_prompt.py)
* `max_data_tokens`: token budget for the context data
* `map_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `map` stage
* `reduce_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `reduce` stage
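
All three prompt parameters are now optional. A runnable sketch of the fallback semantics, using the import path introduced by this commit:

```python
# None now means "use the packaged default prompt".
from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT

def resolve_map_prompt(map_system_prompt: str | None) -> str:
    # Mirrors what GlobalSearch.__init__ does after this change.
    return map_system_prompt or MAP_SYSTEM_PROMPT

assert resolve_map_prompt(None) is MAP_SYSTEM_PROMPT
assert resolve_map_prompt("my custom map prompt") == "my custom map prompt"
```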

View File

@ -50,7 +50,7 @@ Below are the key parameters of the [LocalSearch class](https://github.com/micro
* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects
* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/system_prompt.py)
* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/local_search_system_prompt.py)
* `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
* `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
* `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the search prompt
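
As with global search, `system_prompt` may now be omitted or loaded from the file exported by `graphrag init`. A hedged sketch, assuming the default project layout:

```python
from pathlib import Path

from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT

prompt_path = Path("prompts/local_search_system_prompt.txt")  # assumed layout
system_prompt = (
    prompt_path.read_text(encoding="utf-8")
    if prompt_path.exists()
    else LOCAL_SEARCH_SYSTEM_PROMPT  # packaged default
)
```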

View File

@ -13,7 +13,7 @@ Below are the key parameters of the [Question Generation class](https://github.c
* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects, using the same context builder class as in local search
* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/question_gen/system_prompt.py)
* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/question_gen_system_prompt.py)
* `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
* `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the question generation prompt
* `callbacks`: optional callback functions, which can be used to provide custom event handlers for the LLM's completion streaming events
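
The default question-generation prompt now lives under `graphrag.prompts.query`. To inspect it before writing an override (assuming this commit's module layout):

```python
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT

# Print the opening of the packaged default before customizing it.
print(QUESTION_SYSTEM_PROMPT[:300])
```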

View File

@ -98,6 +98,14 @@ async def global_search(
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = _load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
search_engine = get_global_search_engine(
config,
reports=reports,
@ -105,6 +113,9 @@ async def global_search(
communities=_communities,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
@ -156,6 +167,14 @@ async def global_search_streaming(
dynamic_community_selection=dynamic_community_selection,
)
_entities = read_indexer_entities(nodes, entities, community_level=community_level)
map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
reduce_prompt = _load_search_prompt(
config.root_dir, config.global_search.reduce_prompt
)
knowledge_prompt = _load_search_prompt(
config.root_dir, config.global_search.knowledge_prompt
)
search_engine = get_global_search_engine(
config,
reports=reports,
@ -163,6 +182,9 @@ async def global_search_streaming(
communities=_communities,
response_type=response_type,
dynamic_community_selection=dynamic_community_selection,
map_system_prompt=map_prompt,
reduce_system_prompt=reduce_prompt,
general_knowledge_inclusion_prompt=knowledge_prompt,
)
search_result = search_engine.astream_search(query=query)
@ -238,6 +260,7 @@ async def local_search(
_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)
search_engine = get_local_search_engine(
config=config,
@ -248,6 +271,7 @@ async def local_search(
covariates={"claims": _covariates},
description_embedding_store=description_embedding_store, # type: ignore
response_type=response_type,
system_prompt=prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
@ -312,6 +336,7 @@ async def local_search_streaming(
_entities = read_indexer_entities(nodes, entities, community_level)
_covariates = read_indexer_covariates(covariates) if covariates is not None else []
prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)
search_engine = get_local_search_engine(
config=config,
@ -322,6 +347,7 @@ async def local_search_streaming(
covariates={"claims": _covariates},
description_embedding_store=description_embedding_store, # type: ignore
response_type=response_type,
system_prompt=prompt,
)
search_result = search_engine.astream_search(query=query)
@ -401,7 +427,7 @@ async def drift_search(
_entities = read_indexer_entities(nodes, entities, community_level)
_reports = read_indexer_reports(community_reports, nodes, community_level)
read_indexer_report_embeddings(_reports, full_content_embedding_store)
prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
search_engine = get_drift_search_engine(
config=config,
reports=_reports,
@ -409,6 +435,7 @@ async def drift_search(
entities=_entities,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
local_system_prompt=prompt,
)
result: SearchResult = await search_engine.asearch(query=query)
@ -551,3 +578,17 @@ def _reformat_context_data(context_data: dict) -> dict:
continue
final_format[key] = records
return final_format
def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
"""
Load the search prompt from disk if configured.
If not, leave it empty - the search functions will load their defaults.
"""
if prompt_config:
prompt_file = Path(root_dir) / prompt_config
if prompt_file.exists():
return prompt_file.read_bytes().decode(encoding="utf-8")
return None
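
The loader's precedence is easy to verify in isolation. A runnable sketch that mirrors `_load_search_prompt` (reimplemented here so it runs without the package installed):

```python
import tempfile
from pathlib import Path

def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
    """Illustrative mirror of _load_search_prompt above."""
    if prompt_config:
        prompt_file = Path(root_dir) / prompt_config
        if prompt_file.exists():
            return prompt_file.read_bytes().decode(encoding="utf-8")
    return None

with tempfile.TemporaryDirectory() as root:
    rel = "prompts/local_search_system_prompt.txt"
    (Path(root) / "prompts").mkdir()
    (Path(root) / rel).write_text("custom", encoding="utf-8")
    assert load_search_prompt(root, rel) == "custom"        # configured and present
    assert load_search_prompt(root, None) is None           # unset: engine default
    assert load_search_prompt(root, "missing.txt") is None  # configured but absent
```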

View File

@ -5,14 +5,24 @@
from pathlib import Path
from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
from graphrag.index.graph.extractors.community_reports.prompts import (
COMMUNITY_REPORT_PROMPT,
)
from graphrag.index.graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
from graphrag.index.graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
from graphrag.index.init_content import INIT_DOTENV, INIT_YAML
from graphrag.logging import ReporterType, create_progress_reporter
from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
from graphrag.prompts.index.community_report import (
COMMUNITY_REPORT_PROMPT,
)
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT
from graphrag.prompts.query.global_search_reduce_system_prompt import (
REDUCE_SYSTEM_PROMPT,
)
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
def initialize_project_at(path: Path) -> None:
@ -40,28 +50,21 @@ def initialize_project_at(path: Path) -> None:
if not prompts_dir.exists():
prompts_dir.mkdir(parents=True, exist_ok=True)
entity_extraction = prompts_dir / "entity_extraction.txt"
if not entity_extraction.exists():
with entity_extraction.open("wb") as file:
file.write(
GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
)
prompts = {
"entity_extraction": GRAPH_EXTRACTION_PROMPT,
"summarize_descriptions": SUMMARIZE_PROMPT,
"claim_extraction": CLAIM_EXTRACTION_PROMPT,
"community_report": COMMUNITY_REPORT_PROMPT,
"drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
"global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
"global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
"global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
"local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
"question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
}
summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
if not summarize_descriptions.exists():
with summarize_descriptions.open("wb") as file:
file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
claim_extraction = prompts_dir / "claim_extraction.txt"
if not claim_extraction.exists():
with claim_extraction.open("wb") as file:
file.write(
CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
)
community_report = prompts_dir / "community_report.txt"
if not community_report.exists():
with community_report.open("wb") as file:
file.write(
COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
)
for name, content in prompts.items():
prompt_file = prompts_dir / f"{name}.txt"
if not prompt_file.exists():
with prompt_file.open("wb") as file:
file.write(content.encode(encoding="utf-8", errors="strict"))
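
After initialization, each entry in the `prompts` dict becomes an editable `.txt` file. A hedged sketch of verifying the export (project root is hypothetical):

```python
from pathlib import Path

prompts_dir = Path("./ragtest/prompts")  # assumed root from a prior init
if prompts_dir.exists():
    for prompt_file in sorted(prompts_dir.glob("*.txt")):
        # Expect the ten files listed above, e.g. entity_extraction.txt,
        # local_search_system_prompt.txt, question_gen_system_prompt.txt, ...
        print(prompt_file.name)
```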

View File

@ -51,6 +51,7 @@ from .models import (
ClaimExtractionConfig,
ClusterGraphConfig,
CommunityReportsConfig,
DRIFTSearchConfig,
EmbedGraphConfig,
EntityExtractionConfig,
GlobalSearchConfig,
@ -85,6 +86,7 @@ __all__ = [
"ClusterGraphConfigInput",
"CommunityReportsConfig",
"CommunityReportsConfigInput",
"DRIFTSearchConfig",
"EmbedGraphConfig",
"EmbedGraphConfigInput",
"EntityExtractionConfig",

View File

@ -39,6 +39,7 @@ from .models import (
ClaimExtractionConfig,
ClusterGraphConfig,
CommunityReportsConfig,
DRIFTSearchConfig,
EmbedGraphConfig,
EntityExtractionConfig,
GlobalSearchConfig,
@ -514,6 +515,7 @@ def create_graphrag_config(
reader.envvar_prefix(Section.local_search),
):
local_search_model = LocalSearchConfig(
prompt=reader.str("prompt") or None,
text_unit_prop=reader.float("text_unit_prop")
or defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
community_prop=reader.float("community_prop")
@ -541,6 +543,9 @@ def create_graphrag_config(
reader.envvar_prefix(Section.global_search),
):
global_search_model = GlobalSearchConfig(
map_prompt=reader.str("map_prompt") or None,
reduce_prompt=reader.str("reduce_prompt") or None,
knowledge_prompt=reader.str("knowledge_prompt") or None,
temperature=reader.float("llm_temperature")
or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
@ -556,6 +561,54 @@ def create_graphrag_config(
concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
)
with (
reader.use(values.get("drift_search")),
reader.envvar_prefix(Section.drift_search),
):
drift_search_model = DRIFTSearchConfig(
prompt=reader.str("prompt") or None,
temperature=reader.float("llm_temperature")
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
n=reader.int("llm_n") or defs.DRIFT_SEARCH_LLM_N,
max_tokens=reader.int(Fragment.max_tokens)
or defs.DRIFT_SEARCH_MAX_TOKENS,
data_max_tokens=reader.int("data_max_tokens")
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
drift_k_followups=reader.int("drift_k_followups")
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
primer_folds=reader.int("primer_folds")
or defs.DRIFT_SEARCH_PRIMER_FOLDS,
primer_llm_max_tokens=reader.int("primer_llm_max_tokens")
or defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
n_depth=reader.int("n_depth") or defs.DRIFT_N_DEPTH,
local_search_text_unit_prop=reader.float("local_search_text_unit_prop")
or defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
local_search_community_prop=reader.float("local_search_community_prop")
or defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
local_search_top_k_mapped_entities=reader.int(
"local_search_top_k_mapped_entities"
)
or defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
local_search_top_k_relationships=reader.int(
"local_search_top_k_relationships"
)
or defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
local_search_max_data_tokens=reader.int("local_search_max_data_tokens")
or defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
local_search_temperature=reader.float("local_search_temperature")
or defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
local_search_top_p=reader.float("local_search_top_p")
or defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
local_search_n=reader.int("local_search_n")
or defs.DRIFT_LOCAL_SEARCH_LLM_N,
local_search_llm_max_gen_tokens=reader.int(
"local_search_llm_max_gen_tokens"
)
or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
)
encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
skip_workflows = reader.list("skip_workflows") or []
@ -583,6 +636,7 @@ def create_graphrag_config(
skip_workflows=skip_workflows,
local_search=local_search_model,
global_search=global_search_model,
drift_search=drift_search_model,
)
@ -649,6 +703,7 @@ class Section(str, Enum):
update_index_storage = "UPDATE_INDEX_STORAGE"
local_search = "LOCAL_SEARCH"
global_search = "GLOBAL_SEARCH"
drift_search = "DRIFT_SEARCH"
def _is_azure(llm_type: LLMType | None) -> bool:

View File

@ -11,6 +11,9 @@ import graphrag.config.defaults as defs
class DRIFTSearchConfig(BaseModel):
"""The default configuration section for Cache."""
prompt: str | None = Field(
description="The drift search prompt to use.", default=None
)
temperature: float = Field(
description="The temperature to use for token generation.",
default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,

View File

@ -11,6 +11,15 @@ import graphrag.config.defaults as defs
class GlobalSearchConfig(BaseModel):
"""The default configuration section for Cache."""
map_prompt: str | None = Field(
description="The global search mapper prompt to use.", default=None
)
reduce_prompt: str | None = Field(
description="The global search reducer to use.", default=None
)
knowledge_prompt: str | None = Field(
description="The global search general prompt to use.", default=None
)
temperature: float | None = Field(
description="The temperature to use for token generation.",
default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
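
The new fields are plain optional Pydantic fields. A runnable sketch (module path assumed per this commit's layout):

```python
from graphrag.config.models.global_search_config import GlobalSearchConfig

# Unset prompt paths stay None, which downstream code treats as
# "load the packaged default".
cfg = GlobalSearchConfig(map_prompt="prompts/global_search_map_system_prompt.txt")
assert cfg.map_prompt == "prompts/global_search_map_system_prompt.txt"
assert cfg.reduce_prompt is None and cfg.knowledge_prompt is None
```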

View File

@ -13,6 +13,7 @@ from .chunking_config import ChunkingConfig
from .claim_extraction_config import ClaimExtractionConfig
from .cluster_graph_config import ClusterGraphConfig
from .community_reports_config import CommunityReportsConfig
from .drift_config import DRIFTSearchConfig
from .embed_graph_config import EmbedGraphConfig
from .entity_extraction_config import EntityExtractionConfig
from .global_search_config import GlobalSearchConfig
@ -141,6 +142,11 @@ class GraphRagConfig(LLMConfig):
)
"""The global search configuration."""
drift_search: DRIFTSearchConfig = Field(
description="The drift search configuration.", default=DRIFTSearchConfig()
)
"""The drift search configuration."""
encoding_model: str = Field(
description="The encoding model to use.", default=defs.ENCODING_MODEL
)

View File

@ -11,6 +11,9 @@ import graphrag.config.defaults as defs
class LocalSearchConfig(BaseModel):
"""The default configuration section for Cache."""
prompt: str | None = Field(
description="The local search prompt to use.", default=None
)
text_unit_prop: float = Field(
description="The text unit proportion.",
default=defs.LOCAL_SEARCH_TEXT_UNIT_PROP,

View File

@ -3,16 +3,13 @@
"""The Indexing Engine graph extractors package root."""
from .claims import CLAIM_EXTRACTION_PROMPT, ClaimExtractor
from .claims import ClaimExtractor
from .community_reports import (
COMMUNITY_REPORT_PROMPT,
CommunityReportsExtractor,
)
from .graph import GraphExtractionResult, GraphExtractor
__all__ = [
"CLAIM_EXTRACTION_PROMPT",
"COMMUNITY_REPORT_PROMPT",
"ClaimExtractor",
"CommunityReportsExtractor",
"GraphExtractionResult",

View File

@ -4,6 +4,5 @@
"""The Indexing Engine graph extractors claims package root."""
from .claim_extractor import ClaimExtractor
from .prompts import CLAIM_EXTRACTION_PROMPT
__all__ = ["CLAIM_EXTRACTION_PROMPT", "ClaimExtractor"]
__all__ = ["ClaimExtractor"]

View File

@ -13,8 +13,7 @@ import tiktoken
import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.llm import CompletionLLM
from .prompts import (
from graphrag.prompts.index.claim_extraction import (
CLAIM_EXTRACTION_PROMPT,
CONTINUE_PROMPT,
LOOP_PROMPT,

View File

@ -8,7 +8,6 @@ import graphrag.index.graph.extractors.community_reports.schemas as schemas
from .build_mixed_context import build_mixed_context
from .community_reports_extractor import CommunityReportsExtractor
from .prep_community_report_context import prep_community_report_context
from .prompts import COMMUNITY_REPORT_PROMPT
from .sort_context import sort_context
from .utils import (
filter_claims_to_nodes,
@ -20,7 +19,6 @@ from .utils import (
)
__all__ = [
"COMMUNITY_REPORT_PROMPT",
"CommunityReportsExtractor",
"build_mixed_context",
"filter_claims_to_nodes",

View File

@ -11,8 +11,7 @@ from typing import Any
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils import dict_has_keys_with_types
from graphrag.llm import CompletionLLM
from .prompts import COMMUNITY_REPORT_PROMPT
from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
log = logging.getLogger(__name__)

View File

@ -8,11 +8,9 @@ from .graph_extractor import (
GraphExtractionResult,
GraphExtractor,
)
from .prompts import GRAPH_EXTRACTION_PROMPT
__all__ = [
"DEFAULT_ENTITY_TYPES",
"GRAPH_EXTRACTION_PROMPT",
"GraphExtractionResult",
"GraphExtractor",
]

View File

@ -17,8 +17,11 @@ import graphrag.config.defaults as defs
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils import clean_str
from graphrag.llm import CompletionLLM
from .prompts import CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, LOOP_PROMPT
from graphrag.prompts.index.entity_extraction import (
CONTINUE_PROMPT,
GRAPH_EXTRACTION_PROMPT,
LOOP_PROMPT,
)
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"

View File

@ -7,6 +7,5 @@ from .description_summary_extractor import (
SummarizationResult,
SummarizeExtractor,
)
from .prompts import SUMMARIZE_PROMPT
__all__ = ["SUMMARIZE_PROMPT", "SummarizationResult", "SummarizeExtractor"]
__all__ = ["SummarizationResult", "SummarizeExtractor"]

View File

@ -9,8 +9,7 @@ from dataclasses import dataclass
from graphrag.index.typing import ErrorHandlerFn
from graphrag.index.utils.tokens import num_tokens_from_string
from graphrag.llm import CompletionLLM
from .prompts import SUMMARIZE_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
# Max token size for input prompts
DEFAULT_MAX_INPUT_TOKENS = 4_000

View File

@ -158,6 +158,7 @@ snapshots:
transient: false
local_search:
prompt: "prompts/local_search_system_prompt.txt"
# text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP}
# community_prop: {defs.LOCAL_SEARCH_COMMUNITY_PROP}
# conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
@ -169,6 +170,9 @@ local_search:
# max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}
global_search:
map_prompt: "prompts/global_search_map_system_prompt.txt"
reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"
# llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
# llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
# llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
@ -177,6 +181,28 @@ global_search:
# map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
# reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS}
# concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
drift_search:
prompt: "prompts/drift_search_system_prompt.txt"
# temperature: {defs.DRIFT_SEARCH_LLM_TEMPERATURE}
# top_p: {defs.DRIFT_SEARCH_LLM_TOP_P}
# n: {defs.DRIFT_SEARCH_LLM_N}
# max_tokens: {defs.DRIFT_SEARCH_MAX_TOKENS}
# data_max_tokens: {defs.DRIFT_SEARCH_DATA_MAX_TOKENS}
# concurrency: {defs.DRIFT_SEARCH_CONCURRENCY}
# drift_k_followups: {defs.DRIFT_SEARCH_K_FOLLOW_UPS}
# primer_folds: {defs.DRIFT_SEARCH_PRIMER_FOLDS}
# primer_llm_max_tokens: {defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS}
# n_depth: {defs.DRIFT_N_DEPTH}
# local_search_text_unit_prop: {defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP}
# local_search_community_prop: {defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP}
# local_search_top_k_mapped_entities: {defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
# local_search_top_k_relationships: {defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
# local_search_max_data_tokens: {defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS}
# local_search_temperature: {defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE}
# local_search_top_p: {defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P}
# local_search_n: {defs.DRIFT_LOCAL_SEARCH_LLM_N}
# local_search_llm_max_gen_tokens: {defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS}
"""
INIT_DOTENV = """\

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All prompts for the system."""

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All prompts for indexing."""

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""All prompts for query."""

View File

@ -0,0 +1,9 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Global Search system prompts."""
GENERAL_KNOWLEDGE_INSTRUCTION = """
The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example:
"This is an example sentence supported by real-world knowledge [LLM: verify]."
"""

View File

@ -81,8 +81,3 @@ Add sections and commentary to the response as appropriate for the length and fo
NO_DATA_ANSWER = (
"I am sorry but I am unable to answer this question given the provided data."
)
GENERAL_KNOWLEDGE_INSTRUCTION = """
The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example:
"This is an example sentence supported by real-world knowledge [LLM: verify]."
"""

View File

@ -42,6 +42,7 @@ def get_local_search_engine(
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
system_prompt: str | None = None,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
@ -52,6 +53,7 @@ def get_local_search_engine(
return LocalSearch(
llm=llm,
system_prompt=system_prompt,
context_builder=LocalSearchMixedContext(
community_reports=reports,
text_units=text_units,
@ -95,6 +97,9 @@ def get_global_search_engine(
communities: list[Community],
response_type: str,
dynamic_community_selection: bool = False,
map_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
general_knowledge_inclusion_prompt: str | None = None,
) -> GlobalSearch:
"""Create a global search engine based on data + configuration."""
token_encoder = tiktoken.get_encoding(config.encoding_model)
@ -118,6 +123,9 @@ def get_global_search_engine(
return GlobalSearch(
llm=get_llm(config),
map_system_prompt=map_system_prompt,
reduce_system_prompt=reduce_system_prompt,
general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
context_builder=GlobalCommunityContext(
community_reports=reports,
communities=communities,
@ -166,6 +174,7 @@ def get_drift_search_engine(
entities: list[Entity],
relationships: list[Relationship],
description_embedding_store: BaseVectorStore,
local_system_prompt: str | None = None,
) -> DRIFTSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
@ -182,6 +191,8 @@ def get_drift_search_engine(
reports=reports,
entity_text_embeddings=description_embedding_store,
text_units=text_units,
local_system_prompt=local_system_prompt,
config=config.drift_search,
),
token_encoder=token_encoder,
)

View File

@ -9,6 +9,7 @@ from typing import Any
import tiktoken
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
@ -16,7 +17,6 @@ from graphrag.query.context_builder.conversation_history import (
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT
log = logging.getLogger(__name__)

View File

@ -19,14 +19,14 @@ from graphrag.model import (
Relationship,
TextUnit,
)
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.structured_search.base import DRIFTContextBuilder
from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor
from graphrag.query.structured_search.drift_search.system_prompt import (
DRIFT_LOCAL_SYSTEM_PROMPT,
)
from graphrag.query.structured_search.local_search.mixed_context import (
LocalSearchMixedContext,
)
@ -51,7 +51,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
token_encoder: tiktoken.Encoding | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
@ -59,7 +59,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
self.chat_llm = chat_llm
self.text_embedder = text_embedder
self.token_encoder = token_encoder
self.local_system_prompt = local_system_prompt
self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT
self.entities = entities
self.entity_text_embeddings = entity_text_embeddings

View File

@ -15,13 +15,13 @@ from tqdm.asyncio import tqdm_asyncio
from graphrag.config.models.drift_config import DRIFTSearchConfig
from graphrag.model import CommunityReport
from graphrag.prompts.query.drift_search_system_prompt import (
DRIFT_PRIMER_PROMPT,
)
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.structured_search.drift_search.system_prompt import (
DRIFT_PRIMER_PROMPT,
)
log = logging.getLogger(__name__)

View File

@ -16,6 +16,16 @@ import tiktoken
from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback
from graphrag.llm.openai.utils import try_parse_json_object
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
)
from graphrag.prompts.query.global_search_map_system_prompt import (
MAP_SYSTEM_PROMPT,
)
from graphrag.prompts.query.global_search_reduce_system_prompt import (
NO_DATA_ANSWER,
REDUCE_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
@ -23,14 +33,6 @@ from graphrag.query.context_builder.conversation_history import (
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.global_search.map_system_prompt import (
MAP_SYSTEM_PROMPT,
)
from graphrag.query.structured_search.global_search.reduce_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
NO_DATA_ANSWER,
REDUCE_SYSTEM_PROMPT,
)
DEFAULT_MAP_LLM_PARAMS = {
"max_tokens": 1000,
@ -62,11 +64,11 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
llm: BaseLLM,
context_builder: GlobalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
map_system_prompt: str = MAP_SYSTEM_PROMPT,
reduce_system_prompt: str = REDUCE_SYSTEM_PROMPT,
map_system_prompt: str | None = None,
reduce_system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
allow_general_knowledge: bool = False,
general_knowledge_inclusion_prompt: str = GENERAL_KNOWLEDGE_INSTRUCTION,
general_knowledge_inclusion_prompt: str | None = None,
json_mode: bool = True,
callbacks: list[GlobalSearchLLMCallback] | None = None,
max_data_tokens: int = 8000,
@ -81,11 +83,13 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
token_encoder=token_encoder,
context_builder_params=context_builder_params,
)
self.map_system_prompt = map_system_prompt
self.reduce_system_prompt = reduce_system_prompt
self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT
self.reduce_system_prompt = reduce_system_prompt or REDUCE_SYSTEM_PROMPT
self.response_type = response_type
self.allow_general_knowledge = allow_general_knowledge
self.general_knowledge_inclusion_prompt = general_knowledge_inclusion_prompt
self.general_knowledge_inclusion_prompt = (
general_knowledge_inclusion_prompt or GENERAL_KNOWLEDGE_INSTRUCTION
)
self.callbacks = callbacks
self.max_data_tokens = max_data_tokens

View File

@ -10,6 +10,9 @@ from typing import Any
import tiktoken
from graphrag.prompts.query.local_search_system_prompt import (
LOCAL_SEARCH_SYSTEM_PROMPT,
)
from graphrag.query.context_builder.builders import LocalContextBuilder
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
@ -17,9 +20,6 @@ from graphrag.query.context_builder.conversation_history import (
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.local_search.system_prompt import (
LOCAL_SEARCH_SYSTEM_PROMPT,
)
DEFAULT_LLM_PARAMS = {
"max_tokens": 1500,
@ -37,7 +37,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
llm: BaseLLM,
context_builder: LocalContextBuilder,
token_encoder: tiktoken.Encoding | None = None,
system_prompt: str = LOCAL_SEARCH_SYSTEM_PROMPT,
system_prompt: str | None = None,
response_type: str = "multiple paragraphs",
callbacks: list[BaseLLMCallback] | None = None,
llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
@ -50,7 +50,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
llm_params=llm_params,
context_builder_params=context_builder_params or {},
)
self.system_prompt = system_prompt
self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
self.callbacks = callbacks
self.response_type = response_type

View File

@ -28,6 +28,7 @@ from graphrag.config import (
ClusterGraphConfigInput,
CommunityReportsConfig,
CommunityReportsConfigInput,
DRIFTSearchConfig,
EmbedGraphConfig,
EmbedGraphConfigInput,
EntityExtractionConfig,
@ -202,6 +203,7 @@ class TestDefaultConfig(unittest.TestCase):
assert ClaimExtractionConfig is not None
assert ClusterGraphConfig is not None
assert CommunityReportsConfig is not None
assert DRIFTSearchConfig is not None
assert EmbedGraphConfig is not None
assert EntityExtractionConfig is not None
assert GlobalSearchConfig is not None