Move prompts (#1404)

* Move indexing prompts to root
* Move query prompts to root
* Export query prompts during init
* Extract general knowledge prompt
* Load query prompts from disk
* Semver
* Fix unit tests
Commit 51912b2e03 (parent c8c354e357)
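The practical effect of this change is that every default prompt becomes an importable module-level string under a single `graphrag.prompts` package. A minimal sketch of what that enables, using import paths taken from this diff; the tweak itself is illustrative, not part of the PR:

```python
# Every default prompt is a module-level string under graphrag.prompts.
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.query.local_search_system_prompt import (
    LOCAL_SEARCH_SYSTEM_PROMPT,
)

# Callers can inspect or tweak a default before injecting it into a
# search engine or indexing workflow (the tweak below is hypothetical).
custom_local_prompt = LOCAL_SEARCH_SYSTEM_PROMPT + "\nAnswer in one paragraph."
print(len(GRAPH_EXTRACTION_PROMPT), len(custom_local_prompt))
```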
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Centralized prompts and export all for easier injection."
+}
@@ -8,7 +8,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Entity/Relationship Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/graph/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/entity_extraction.py)

 ### Tokens (values provided by extractor)

@@ -20,7 +20,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Summarize Entity/Relationship Descriptions

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/summarize_descriptions.py)

 ### Tokens (values provided by extractor)

@@ -29,7 +29,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

 ## Claim Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/claims/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/claim_extraction.py)

 ### Tokens (values provided by extractor)

@@ -47,7 +47,7 @@ See the [configuration documentation](../config/overview.md) for details on how

 ## Generate Community Reports

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/community_reports/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/community_report.py)

 ### Tokens (values provided by extractor)

@@ -56,11 +56,11 @@ Below are the key parameters of the [GlobalSearch class](https://github.com/micr

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/community_context.py) object to be used for preparing context data from community reports
-* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/map_system_prompt.py)
-* `reduce_system_prompt`: prompt template used in the `reduce` stage. Default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_map_system_prompt.py)
+* `reduce_system_prompt`: prompt template used in the `reduce` stage. Default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_reduce_system_prompt.py)
 * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
 * `allow_general_knowledge`: setting this to True will include additional instructions to the `reduce_system_prompt` to prompt the LLM to incorporate relevant real-world knowledge outside of the dataset. Note that this may increase hallucinations, but can be useful for certain scenarios. Default is False
-* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_knowledge_system_prompt.py)
 * `max_data_tokens`: token budget for the context data
 * `map_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `map` stage
 * `reduce_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `reduce` stage
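To make the parameter list concrete, here is a minimal construction sketch (editorial, not from the PR). It assumes an `llm` and a `context_builder` built elsewhere and must run inside an async context; after this change the prompt arguments can simply be passed as None (or omitted) to get the packaged defaults:

```python
from graphrag.query.structured_search.global_search.search import GlobalSearch


async def run(llm, context_builder):  # llm / context_builder: assumed prebuilt
    engine = GlobalSearch(
        llm=llm,
        context_builder=context_builder,
        map_system_prompt=None,     # None -> packaged MAP_SYSTEM_PROMPT
        reduce_system_prompt=None,  # None -> packaged REDUCE_SYSTEM_PROMPT
        allow_general_knowledge=False,
        response_type="Multiple Paragraphs",
        max_data_tokens=8000,
        map_llm_params={"max_tokens": 1000, "temperature": 0.0},
        reduce_llm_params={"max_tokens": 2000, "temperature": 0.0},
    )
    result = await engine.asearch(query="What are the major themes in the dataset?")
    return result.response
```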
@@ -50,7 +50,7 @@ Below are the key parameters of the [LocalSearch class](https://github.com/micro

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects
-* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/system_prompt.py)
+* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/local_search_system_prompt.py)
 * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
 * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
 * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the search prompt
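A matching sketch for local search, under the same assumptions (prebuilt `llm` and mixed-context `context_builder`, async context); `system_prompt=None` now falls back to the packaged default:

```python
from graphrag.query.structured_search.local_search.search import LocalSearch


async def run(llm, context_builder):  # llm / context_builder: assumed prebuilt
    engine = LocalSearch(
        llm=llm,
        context_builder=context_builder,
        system_prompt=None,  # None -> packaged LOCAL_SEARCH_SYSTEM_PROMPT
        response_type="Multiple Paragraphs",
        llm_params={"max_tokens": 1500, "temperature": 0.0},
    )
    result = await engine.asearch(query="Who is Scrooge?")
    return result.response
```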
@@ -13,7 +13,7 @@ Below are the key parameters of the [Question Generation class](https://github.c

 * `llm`: OpenAI model object to be used for response generation
 * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects, using the same context builder class as in local search
-* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/question_gen/system_prompt.py)
+* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/question_gen_system_prompt.py)
 * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
 * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the question generation prompt
 * `callbacks`: optional callback functions, can be used to provide custom event handlers for LLM's completion streaming events
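For completeness, a question-generation sketch. The class and method names (`LocalQuestionGen`, `agenerate`) are taken from the repository's question_gen module as this editor recalls it and should be treated as assumptions; `llm` and `context_builder` are again assumed prebuilt:

```python
from graphrag.query.question_gen.local_gen import LocalQuestionGen


async def suggest(llm, context_builder):  # both assumed prebuilt, as in local search
    question_gen = LocalQuestionGen(
        llm=llm,
        context_builder=context_builder,  # same mixed-context builder as local search
        llm_params={"max_tokens": 500, "temperature": 0.0},
    )
    candidates = await question_gen.agenerate(
        question_history=["Who is Scrooge?"],
        context_data=None,  # None -> build context from the history
        question_count=5,
    )
    return candidates.response
```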
@@ -98,6 +98,14 @@ async def global_search(
         dynamic_community_selection=dynamic_community_selection,
     )
     _entities = read_indexer_entities(nodes, entities, community_level=community_level)
+    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.reduce_prompt
+    )
+    knowledge_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.knowledge_prompt
+    )
+
     search_engine = get_global_search_engine(
         config,
         reports=reports,
@@ -105,6 +113,9 @@ async def global_search(
         communities=_communities,
         response_type=response_type,
         dynamic_community_selection=dynamic_community_selection,
+        map_system_prompt=map_prompt,
+        reduce_system_prompt=reduce_prompt,
+        general_knowledge_inclusion_prompt=knowledge_prompt,
     )
     result: SearchResult = await search_engine.asearch(query=query)
     response = result.response
@@ -156,6 +167,14 @@ async def global_search_streaming(
         dynamic_community_selection=dynamic_community_selection,
     )
     _entities = read_indexer_entities(nodes, entities, community_level=community_level)
+    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.reduce_prompt
+    )
+    knowledge_prompt = _load_search_prompt(
+        config.root_dir, config.global_search.knowledge_prompt
+    )
+
     search_engine = get_global_search_engine(
         config,
         reports=reports,
@@ -163,6 +182,9 @@ async def global_search_streaming(
         communities=_communities,
         response_type=response_type,
         dynamic_community_selection=dynamic_community_selection,
+        map_system_prompt=map_prompt,
+        reduce_system_prompt=reduce_prompt,
+        general_knowledge_inclusion_prompt=knowledge_prompt,
     )
     search_result = search_engine.astream_search(query=query)

@@ -238,6 +260,7 @@ async def local_search(

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
+    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

     search_engine = get_local_search_engine(
         config=config,
@@ -248,6 +271,7 @@ async def local_search(
         covariates={"claims": _covariates},
         description_embedding_store=description_embedding_store,  # type: ignore
         response_type=response_type,
+        system_prompt=prompt,
     )

     result: SearchResult = await search_engine.asearch(query=query)
@@ -312,6 +336,7 @@ async def local_search_streaming(

     _entities = read_indexer_entities(nodes, entities, community_level)
     _covariates = read_indexer_covariates(covariates) if covariates is not None else []
+    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

     search_engine = get_local_search_engine(
         config=config,
@@ -322,6 +347,7 @@ async def local_search_streaming(
         covariates={"claims": _covariates},
         description_embedding_store=description_embedding_store,  # type: ignore
         response_type=response_type,
+        system_prompt=prompt,
     )
     search_result = search_engine.astream_search(query=query)

@@ -401,7 +427,7 @@ async def drift_search(
     _entities = read_indexer_entities(nodes, entities, community_level)
     _reports = read_indexer_reports(community_reports, nodes, community_level)
     read_indexer_report_embeddings(_reports, full_content_embedding_store)
-
+    prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
     search_engine = get_drift_search_engine(
         config=config,
         reports=_reports,
@@ -409,6 +435,7 @@ async def drift_search(
         entities=_entities,
         relationships=read_indexer_relationships(relationships),
         description_embedding_store=description_embedding_store,  # type: ignore
+        local_system_prompt=prompt,
     )

     result: SearchResult = await search_engine.asearch(query=query)
@@ -551,3 +578,17 @@ def _reformat_context_data(context_data: dict) -> dict:
             continue
         final_format[key] = records
     return final_format
+
+
+def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
+    """
+    Load the search prompt from disk if configured.
+
+    If not, leave it empty - the search functions will load their defaults.
+
+    """
+    if prompt_config:
+        prompt_file = Path(root_dir) / prompt_config
+        if prompt_file.exists():
+            return prompt_file.read_bytes().decode(encoding="utf-8")
+    return None
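The loader's contract, as the function above shows: return the file contents when the config points at an existing file under the project root, otherwise None, so each engine's own fallback picks the packaged default. A small sketch of the call pattern (the paths are illustrative):

```python
# Hypothetical layout: ./ragtest/prompts/local_search_system_prompt.txt
prompt = _load_search_prompt("./ragtest", "prompts/local_search_system_prompt.txt")
# -> the file's text if it exists; None if the config entry is unset
#    or the file is missing. Each engine then applies its own fallback:
#    self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
```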
@@ -5,14 +5,24 @@

 from pathlib import Path

-from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
-from graphrag.index.graph.extractors.community_reports.prompts import (
-    COMMUNITY_REPORT_PROMPT,
-)
-from graphrag.index.graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
-from graphrag.index.graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
 from graphrag.index.init_content import INIT_DOTENV, INIT_YAML
 from graphrag.logging import ReporterType, create_progress_reporter
+from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
+from graphrag.prompts.index.community_report import (
+    COMMUNITY_REPORT_PROMPT,
+)
+from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
+from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
+from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
+from graphrag.prompts.query.global_search_knowledge_system_prompt import (
+    GENERAL_KNOWLEDGE_INSTRUCTION,
+)
+from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT
+from graphrag.prompts.query.global_search_reduce_system_prompt import (
+    REDUCE_SYSTEM_PROMPT,
+)
+from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
+from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT


 def initialize_project_at(path: Path) -> None:
@@ -40,28 +50,21 @@ def initialize_project_at(path: Path) -> None:
     if not prompts_dir.exists():
         prompts_dir.mkdir(parents=True, exist_ok=True)

-    entity_extraction = prompts_dir / "entity_extraction.txt"
-    if not entity_extraction.exists():
-        with entity_extraction.open("wb") as file:
-            file.write(
-                GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
+    prompts = {
+        "entity_extraction": GRAPH_EXTRACTION_PROMPT,
+        "summarize_descriptions": SUMMARIZE_PROMPT,
+        "claim_extraction": CLAIM_EXTRACTION_PROMPT,
+        "community_report": COMMUNITY_REPORT_PROMPT,
+        "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
+        "global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
+        "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
+        "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
+        "local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
+        "question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
+    }

-    summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
-    if not summarize_descriptions.exists():
-        with summarize_descriptions.open("wb") as file:
-            file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))
-
-    claim_extraction = prompts_dir / "claim_extraction.txt"
-    if not claim_extraction.exists():
-        with claim_extraction.open("wb") as file:
-            file.write(
-                CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
-
-    community_report = prompts_dir / "community_report.txt"
-    if not community_report.exists():
-        with community_report.open("wb") as file:
-            file.write(
-                COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
-            )
+    for name, content in prompts.items():
+        prompt_file = prompts_dir / f"{name}.txt"
+        if not prompt_file.exists():
+            with prompt_file.open("wb") as file:
+                file.write(content.encode(encoding="utf-8", errors="strict"))
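After this refactor, project initialization exports one `.txt` file per entry in the `prompts` dict above. A quick editorial check of that behavior (the `./ragtest` root is an assumed example created with `graphrag init --root ./ragtest`):

```python
from pathlib import Path

# assumed project root created by `graphrag init --root ./ragtest`
prompts_dir = Path("./ragtest/prompts")

expected = {
    "entity_extraction", "summarize_descriptions", "claim_extraction",
    "community_report", "drift_search_system_prompt",
    "global_search_map_system_prompt", "global_search_reduce_system_prompt",
    "global_search_knowledge_system_prompt", "local_search_system_prompt",
    "question_gen_system_prompt",
}
found = {p.stem for p in prompts_dir.glob("*.txt")}
assert expected <= found, f"missing prompt files: {expected - found}"
```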
@@ -51,6 +51,7 @@ from .models import (
     ClaimExtractionConfig,
     ClusterGraphConfig,
     CommunityReportsConfig,
+    DRIFTSearchConfig,
     EmbedGraphConfig,
     EntityExtractionConfig,
     GlobalSearchConfig,
@@ -85,6 +86,7 @@ __all__ = [
     "ClusterGraphConfigInput",
     "CommunityReportsConfig",
     "CommunityReportsConfigInput",
+    "DRIFTSearchConfig",
     "EmbedGraphConfig",
     "EmbedGraphConfigInput",
     "EntityExtractionConfig",
@@ -39,6 +39,7 @@ from .models import (
     ClaimExtractionConfig,
     ClusterGraphConfig,
     CommunityReportsConfig,
+    DRIFTSearchConfig,
     EmbedGraphConfig,
     EntityExtractionConfig,
     GlobalSearchConfig,
@@ -514,6 +515,7 @@ def create_graphrag_config(
         reader.envvar_prefix(Section.local_search),
     ):
         local_search_model = LocalSearchConfig(
+            prompt=reader.str("prompt") or None,
             text_unit_prop=reader.float("text_unit_prop")
             or defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
             community_prop=reader.float("community_prop")
@@ -541,6 +543,9 @@ def create_graphrag_config(
         reader.envvar_prefix(Section.global_search),
     ):
         global_search_model = GlobalSearchConfig(
+            map_prompt=reader.str("map_prompt") or None,
+            reduce_prompt=reader.str("reduce_prompt") or None,
+            knowledge_prompt=reader.str("knowledge_prompt") or None,
             temperature=reader.float("llm_temperature")
             or defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
             top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P,
@@ -556,6 +561,54 @@ def create_graphrag_config(
             concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY,
         )

+    with (
+        reader.use(values.get("drift_search")),
+        reader.envvar_prefix(Section.drift_search),
+    ):
+        drift_search_model = DRIFTSearchConfig(
+            prompt=reader.str("prompt") or None,
+            temperature=reader.float("llm_temperature")
+            or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
+            top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
+            n=reader.int("llm_n") or defs.DRIFT_SEARCH_LLM_N,
+            max_tokens=reader.int(Fragment.max_tokens)
+            or defs.DRIFT_SEARCH_MAX_TOKENS,
+            data_max_tokens=reader.int("data_max_tokens")
+            or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
+            concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
+            drift_k_followups=reader.int("drift_k_followups")
+            or defs.DRIFT_SEARCH_K_FOLLOW_UPS,
+            primer_folds=reader.int("primer_folds")
+            or defs.DRIFT_SEARCH_PRIMER_FOLDS,
+            primer_llm_max_tokens=reader.int("primer_llm_max_tokens")
+            or defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS,
+            n_depth=reader.int("n_depth") or defs.DRIFT_N_DEPTH,
+            local_search_text_unit_prop=reader.float("local_search_text_unit_prop")
+            or defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP,
+            local_search_community_prop=reader.float("local_search_community_prop")
+            or defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP,
+            local_search_top_k_mapped_entities=reader.int(
+                "local_search_top_k_mapped_entities"
+            )
+            or defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES,
+            local_search_top_k_relationships=reader.int(
+                "local_search_top_k_relationships"
+            )
+            or defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS,
+            local_search_max_data_tokens=reader.int("local_search_max_data_tokens")
+            or defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS,
+            local_search_temperature=reader.float("local_search_temperature")
+            or defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE,
+            local_search_top_p=reader.float("local_search_top_p")
+            or defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P,
+            local_search_n=reader.int("local_search_n")
+            or defs.DRIFT_LOCAL_SEARCH_LLM_N,
+            local_search_llm_max_gen_tokens=reader.int(
+                "local_search_llm_max_gen_tokens"
+            )
+            or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
+        )
+
     encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
     skip_workflows = reader.list("skip_workflows") or []

@@ -583,6 +636,7 @@ def create_graphrag_config(
         skip_workflows=skip_workflows,
         local_search=local_search_model,
         global_search=global_search_model,
+        drift_search=drift_search_model,
     )

@@ -649,6 +703,7 @@ class Section(str, Enum):
     update_index_storage = "UPDATE_INDEX_STORAGE"
     local_search = "LOCAL_SEARCH"
     global_search = "GLOBAL_SEARCH"
+    drift_search = "DRIFT_SEARCH"


 def _is_azure(llm_type: LLMType | None) -> bool:
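A hedged sketch of how the new prompt keys surface through `create_graphrag_config`. The `values`/`root_dir` signature and the environment-key requirement reflect this editor's reading of the config loader and should be treated as assumptions; the paths are the ones `graphrag init` exports:

```python
import os

from graphrag.config import create_graphrag_config

# Assumption: config loading reads LLM settings from the environment,
# so an API key must be present for the call to succeed.
os.environ.setdefault("GRAPHRAG_API_KEY", "<api-key-placeholder>")

config = create_graphrag_config(
    values={
        "local_search": {"prompt": "prompts/local_search_system_prompt.txt"},
        "global_search": {
            "map_prompt": "prompts/global_search_map_system_prompt.txt",
            "reduce_prompt": "prompts/global_search_reduce_system_prompt.txt",
            "knowledge_prompt": "prompts/global_search_knowledge_system_prompt.txt",
        },
        "drift_search": {"prompt": "prompts/drift_search_system_prompt.txt"},
    },
    root_dir="./ragtest",  # assumed project root
)
assert config.global_search.map_prompt == "prompts/global_search_map_system_prompt.txt"
```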
@@ -11,6 +11,9 @@ import graphrag.config.defaults as defs
 class DRIFTSearchConfig(BaseModel):
     """The default configuration section for Cache."""

+    prompt: str | None = Field(
+        description="The drift search prompt to use.", default=None
+    )
     temperature: float = Field(
         description="The temperature to use for token generation.",
         default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
@@ -11,6 +11,15 @@ import graphrag.config.defaults as defs
 class GlobalSearchConfig(BaseModel):
     """The default configuration section for Cache."""

+    map_prompt: str | None = Field(
+        description="The global search mapper prompt to use.", default=None
+    )
+    reduce_prompt: str | None = Field(
+        description="The global search reducer to use.", default=None
+    )
+    knowledge_prompt: str | None = Field(
+        description="The global search general prompt to use.", default=None
+    )
     temperature: float | None = Field(
         description="The temperature to use for token generation.",
         default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE,
@@ -13,6 +13,7 @@ from .chunking_config import ChunkingConfig
 from .claim_extraction_config import ClaimExtractionConfig
 from .cluster_graph_config import ClusterGraphConfig
 from .community_reports_config import CommunityReportsConfig
+from .drift_config import DRIFTSearchConfig
 from .embed_graph_config import EmbedGraphConfig
 from .entity_extraction_config import EntityExtractionConfig
 from .global_search_config import GlobalSearchConfig
@@ -141,6 +142,11 @@ class GraphRagConfig(LLMConfig):
     )
     """The global search configuration."""

+    drift_search: DRIFTSearchConfig = Field(
+        description="The drift search configuration.", default=DRIFTSearchConfig()
+    )
+    """The drift search configuration."""
+
     encoding_model: str = Field(
         description="The encoding model to use.", default=defs.ENCODING_MODEL
     )
@@ -11,6 +11,9 @@ import graphrag.config.defaults as defs
 class LocalSearchConfig(BaseModel):
     """The default configuration section for Cache."""

+    prompt: str | None = Field(
+        description="The local search prompt to use.", default=None
+    )
     text_unit_prop: float = Field(
         description="The text unit proportion.",
         default=defs.LOCAL_SEARCH_TEXT_UNIT_PROP,
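The config models above all follow the same convention: the prompt fields default to None, which downstream code reads as "use the packaged default". A small sketch using the exported model classes:

```python
from graphrag.config import DRIFTSearchConfig, LocalSearchConfig

# None means "fall back to the packaged default prompt" at search time.
assert LocalSearchConfig().prompt is None
assert DRIFTSearchConfig().prompt is None

# An explicit value is a path relative to the project root; it is resolved
# later by _load_search_prompt, not validated at construction time.
local = LocalSearchConfig(prompt="prompts/local_search_system_prompt.txt")
print(local.prompt)
```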
@@ -3,16 +3,13 @@

 """The Indexing Engine graph extractors package root."""

-from .claims import CLAIM_EXTRACTION_PROMPT, ClaimExtractor
+from .claims import ClaimExtractor
 from .community_reports import (
-    COMMUNITY_REPORT_PROMPT,
     CommunityReportsExtractor,
 )
 from .graph import GraphExtractionResult, GraphExtractor

 __all__ = [
-    "CLAIM_EXTRACTION_PROMPT",
-    "COMMUNITY_REPORT_PROMPT",
     "ClaimExtractor",
     "CommunityReportsExtractor",
     "GraphExtractionResult",
@@ -4,6 +4,5 @@
 """The Indexing Engine graph extractors claims package root."""

 from .claim_extractor import ClaimExtractor
-from .prompts import CLAIM_EXTRACTION_PROMPT

-__all__ = ["CLAIM_EXTRACTION_PROMPT", "ClaimExtractor"]
+__all__ = ["ClaimExtractor"]
@@ -13,8 +13,7 @@ import tiktoken
 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.llm import CompletionLLM
-
-from .prompts import (
+from graphrag.prompts.index.claim_extraction import (
     CLAIM_EXTRACTION_PROMPT,
     CONTINUE_PROMPT,
     LOOP_PROMPT,
@@ -8,7 +8,6 @@ import graphrag.index.graph.extractors.community_reports.schemas as schemas
 from .build_mixed_context import build_mixed_context
 from .community_reports_extractor import CommunityReportsExtractor
 from .prep_community_report_context import prep_community_report_context
-from .prompts import COMMUNITY_REPORT_PROMPT
 from .sort_context import sort_context
 from .utils import (
     filter_claims_to_nodes,
@@ -20,7 +19,6 @@ from .utils import (
 )

 __all__ = [
-    "COMMUNITY_REPORT_PROMPT",
     "CommunityReportsExtractor",
     "build_mixed_context",
     "filter_claims_to_nodes",
@@ -11,8 +11,7 @@ from typing import Any
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils import dict_has_keys_with_types
 from graphrag.llm import CompletionLLM
-
-from .prompts import COMMUNITY_REPORT_PROMPT
+from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

 log = logging.getLogger(__name__)

@@ -8,11 +8,9 @@ from .graph_extractor import (
     GraphExtractionResult,
     GraphExtractor,
 )
-from .prompts import GRAPH_EXTRACTION_PROMPT

 __all__ = [
     "DEFAULT_ENTITY_TYPES",
-    "GRAPH_EXTRACTION_PROMPT",
     "GraphExtractionResult",
     "GraphExtractor",
 ]
@@ -17,8 +17,11 @@ import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils import clean_str
 from graphrag.llm import CompletionLLM
-
-from .prompts import CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, LOOP_PROMPT
+from graphrag.prompts.index.entity_extraction import (
+    CONTINUE_PROMPT,
+    GRAPH_EXTRACTION_PROMPT,
+    LOOP_PROMPT,
+)

 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
@@ -7,6 +7,5 @@ from .description_summary_extractor import (
     SummarizationResult,
     SummarizeExtractor,
 )
-from .prompts import SUMMARIZE_PROMPT

-__all__ = ["SUMMARIZE_PROMPT", "SummarizationResult", "SummarizeExtractor"]
+__all__ = ["SummarizationResult", "SummarizeExtractor"]
@@ -9,8 +9,7 @@ from dataclasses import dataclass
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils.tokens import num_tokens_from_string
 from graphrag.llm import CompletionLLM
-
-from .prompts import SUMMARIZE_PROMPT
+from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT

 # Max token size for input prompts
 DEFAULT_MAX_INPUT_TOKENS = 4_000
@@ -158,6 +158,7 @@ snapshots:
   transient: false

 local_search:
+  prompt: "prompts/local_search_system_prompt.txt"
   # text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP}
   # community_prop: {defs.LOCAL_SEARCH_COMMUNITY_PROP}
   # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS}
@@ -169,6 +170,9 @@ local_search:
   # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS}

 global_search:
+  map_prompt: "prompts/global_search_map_system_prompt.txt"
+  reduce_prompt: "prompts/global_search_reduce_system_prompt.txt"
+  knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt"
   # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling
   # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling
   # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate
@@ -177,6 +181,28 @@ global_search:
   # map_max_tokens: {defs.GLOBAL_SEARCH_MAP_MAX_TOKENS}
   # reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS}
   # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY}
+
+drift_search:
+  prompt: "prompts/drift_search_system_prompt.txt"
+  # temperature: {defs.DRIFT_SEARCH_LLM_TEMPERATURE}
+  # top_p: {defs.DRIFT_SEARCH_LLM_TOP_P}
+  # n: {defs.DRIFT_SEARCH_LLM_N}
+  # max_tokens: {defs.DRIFT_SEARCH_MAX_TOKENS}
+  # data_max_tokens: {defs.DRIFT_SEARCH_DATA_MAX_TOKENS}
+  # concurrency: {defs.DRIFT_SEARCH_CONCURRENCY}
+  # drift_k_followups: {defs.DRIFT_SEARCH_K_FOLLOW_UPS}
+  # primer_folds: {defs.DRIFT_SEARCH_PRIMER_FOLDS}
+  # primer_llm_max_tokens: {defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS}
+  # n_depth: {defs.DRIFT_N_DEPTH}
+  # local_search_text_unit_prop: {defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP}
+  # local_search_community_prop: {defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP}
+  # local_search_top_k_mapped_entities: {defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES}
+  # local_search_top_k_relationships: {defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS}
+  # local_search_max_data_tokens: {defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS}
+  # local_search_temperature: {defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE}
+  # local_search_top_p: {defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P}
+  # local_search_n: {defs.DRIFT_LOCAL_SEARCH_LLM_N}
+  # local_search_llm_max_gen_tokens: {defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS}
 """

 INIT_DOTENV = """\
graphrag/prompts/__init__.py (new file)
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All prompts for the system."""
graphrag/prompts/index/__init__.py (new file)
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All prompts for indexing."""
graphrag/prompts/query/__init__.py (new file)
@@ -0,0 +1,4 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All prompts for query."""
@@ -0,0 +1,9 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Global Search system prompts."""
+
+GENERAL_KNOWLEDGE_INSTRUCTION = """
+The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example:
+"This is an example sentence supported by real-world knowledge [LLM: verify]."
+"""
@@ -81,8 +81,3 @@ Add sections and commentary to the response as appropriate for the length and fo
 NO_DATA_ANSWER = (
     "I am sorry but I am unable to answer this question given the provided data."
 )
-
-GENERAL_KNOWLEDGE_INSTRUCTION = """
-The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example:
-"This is an example sentence supported by real-world knowledge [LLM: verify]."
-"""
@@ -42,6 +42,7 @@ def get_local_search_engine(
     covariates: dict[str, list[Covariate]],
     response_type: str,
     description_embedding_store: BaseVectorStore,
+    system_prompt: str | None = None,
 ) -> LocalSearch:
     """Create a local search engine based on data + configuration."""
     llm = get_llm(config)
@@ -52,6 +53,7 @@ def get_local_search_engine(

     return LocalSearch(
         llm=llm,
+        system_prompt=system_prompt,
         context_builder=LocalSearchMixedContext(
             community_reports=reports,
             text_units=text_units,
@@ -95,6 +97,9 @@ def get_global_search_engine(
     communities: list[Community],
     response_type: str,
     dynamic_community_selection: bool = False,
+    map_system_prompt: str | None = None,
+    reduce_system_prompt: str | None = None,
+    general_knowledge_inclusion_prompt: str | None = None,
 ) -> GlobalSearch:
     """Create a global search engine based on data + configuration."""
     token_encoder = tiktoken.get_encoding(config.encoding_model)
@@ -118,6 +123,9 @@ def get_global_search_engine(

     return GlobalSearch(
         llm=get_llm(config),
+        map_system_prompt=map_system_prompt,
+        reduce_system_prompt=reduce_system_prompt,
+        general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt,
         context_builder=GlobalCommunityContext(
             community_reports=reports,
             communities=communities,
@@ -166,6 +174,7 @@ def get_drift_search_engine(
     entities: list[Entity],
     relationships: list[Relationship],
     description_embedding_store: BaseVectorStore,
+    local_system_prompt: str | None = None,
 ) -> DRIFTSearch:
     """Create a local search engine based on data + configuration."""
     llm = get_llm(config)
@@ -182,6 +191,8 @@ def get_drift_search_engine(
             reports=reports,
             entity_text_embeddings=description_embedding_store,
             text_units=text_units,
+            local_system_prompt=local_system_prompt,
+            config=config.drift_search,
         ),
         token_encoder=token_encoder,
     )
@@ -9,6 +9,7 @@ from typing import Any

 import tiktoken

+from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT
 from graphrag.query.context_builder.builders import LocalContextBuilder
 from graphrag.query.context_builder.conversation_history import (
     ConversationHistory,
@@ -16,7 +17,6 @@ from graphrag.query.context_builder.conversation_history import (
 from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult
-from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT

 log = logging.getLogger(__name__)

@@ -19,14 +19,14 @@ from graphrag.model import (
     Relationship,
     TextUnit,
 )
+from graphrag.prompts.query.drift_search_system_prompt import (
+    DRIFT_LOCAL_SYSTEM_PROMPT,
+)
 from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
 from graphrag.query.llm.base import BaseTextEmbedding
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
 from graphrag.query.structured_search.base import DRIFTContextBuilder
 from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor
-from graphrag.query.structured_search.drift_search.system_prompt import (
-    DRIFT_LOCAL_SYSTEM_PROMPT,
-)
 from graphrag.query.structured_search.local_search.mixed_context import (
     LocalSearchMixedContext,
 )
@@ -51,7 +51,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
         token_encoder: tiktoken.Encoding | None = None,
         embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
         config: DRIFTSearchConfig | None = None,
-        local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT,
+        local_system_prompt: str | None = None,
         local_mixed_context: LocalSearchMixedContext | None = None,
     ):
         """Initialize the DRIFT search context builder with necessary components."""
@@ -59,7 +59,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
         self.chat_llm = chat_llm
         self.text_embedder = text_embedder
         self.token_encoder = token_encoder
-        self.local_system_prompt = local_system_prompt
+        self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT

         self.entities = entities
         self.entity_text_embeddings = entity_text_embeddings
@@ -15,13 +15,13 @@ from tqdm.asyncio import tqdm_asyncio

 from graphrag.config.models.drift_config import DRIFTSearchConfig
 from graphrag.model import CommunityReport
+from graphrag.prompts.query.drift_search_system_prompt import (
+    DRIFT_PRIMER_PROMPT,
+)
 from graphrag.query.llm.base import BaseTextEmbedding
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.structured_search.base import SearchResult
-from graphrag.query.structured_search.drift_search.system_prompt import (
-    DRIFT_PRIMER_PROMPT,
-)

 log = logging.getLogger(__name__)

@@ -16,6 +16,16 @@ import tiktoken

 from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback
 from graphrag.llm.openai.utils import try_parse_json_object
+from graphrag.prompts.query.global_search_knowledge_system_prompt import (
+    GENERAL_KNOWLEDGE_INSTRUCTION,
+)
+from graphrag.prompts.query.global_search_map_system_prompt import (
+    MAP_SYSTEM_PROMPT,
+)
+from graphrag.prompts.query.global_search_reduce_system_prompt import (
+    NO_DATA_ANSWER,
+    REDUCE_SYSTEM_PROMPT,
+)
 from graphrag.query.context_builder.builders import GlobalContextBuilder
 from graphrag.query.context_builder.conversation_history import (
     ConversationHistory,
@@ -23,14 +33,6 @@ from graphrag.query.context_builder.conversation_history import (
 from graphrag.query.llm.base import BaseLLM
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.structured_search.base import BaseSearch, SearchResult
-from graphrag.query.structured_search.global_search.map_system_prompt import (
-    MAP_SYSTEM_PROMPT,
-)
-from graphrag.query.structured_search.global_search.reduce_system_prompt import (
-    GENERAL_KNOWLEDGE_INSTRUCTION,
-    NO_DATA_ANSWER,
-    REDUCE_SYSTEM_PROMPT,
-)

 DEFAULT_MAP_LLM_PARAMS = {
     "max_tokens": 1000,
@@ -62,11 +64,11 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
         llm: BaseLLM,
         context_builder: GlobalContextBuilder,
         token_encoder: tiktoken.Encoding | None = None,
-        map_system_prompt: str = MAP_SYSTEM_PROMPT,
-        reduce_system_prompt: str = REDUCE_SYSTEM_PROMPT,
+        map_system_prompt: str | None = None,
+        reduce_system_prompt: str | None = None,
         response_type: str = "multiple paragraphs",
         allow_general_knowledge: bool = False,
-        general_knowledge_inclusion_prompt: str = GENERAL_KNOWLEDGE_INSTRUCTION,
+        general_knowledge_inclusion_prompt: str | None = None,
         json_mode: bool = True,
         callbacks: list[GlobalSearchLLMCallback] | None = None,
         max_data_tokens: int = 8000,
@@ -81,11 +83,13 @@ class GlobalSearch(BaseSearch[GlobalContextBuilder]):
             token_encoder=token_encoder,
             context_builder_params=context_builder_params,
         )
-        self.map_system_prompt = map_system_prompt
-        self.reduce_system_prompt = reduce_system_prompt
+        self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT
+        self.reduce_system_prompt = reduce_system_prompt or REDUCE_SYSTEM_PROMPT
         self.response_type = response_type
         self.allow_general_knowledge = allow_general_knowledge
-        self.general_knowledge_inclusion_prompt = general_knowledge_inclusion_prompt
+        self.general_knowledge_inclusion_prompt = (
+            general_knowledge_inclusion_prompt or GENERAL_KNOWLEDGE_INSTRUCTION
+        )
         self.callbacks = callbacks
         self.max_data_tokens = max_data_tokens

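One design note on the `None`-default plus `or` fallback used here and in the other search classes: because `or` tests truthiness, an empty string falls back to the packaged default just as None does. A self-contained sketch of that behavior (the constant is a stand-in, not the real prompt text):

```python
MAP_SYSTEM_PROMPT = "...packaged default..."  # stand-in for the real constant


def resolve(map_system_prompt: str | None) -> str:
    # Mirrors GlobalSearch.__init__: any falsy value falls back.
    return map_system_prompt or MAP_SYSTEM_PROMPT


assert resolve(None) == MAP_SYSTEM_PROMPT
assert resolve("") == MAP_SYSTEM_PROMPT  # an empty custom prompt also falls back
assert resolve("my prompt") == "my prompt"
```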
@@ -10,6 +10,9 @@ from typing import Any

 import tiktoken

+from graphrag.prompts.query.local_search_system_prompt import (
+    LOCAL_SEARCH_SYSTEM_PROMPT,
+)
 from graphrag.query.context_builder.builders import LocalContextBuilder
 from graphrag.query.context_builder.conversation_history import (
     ConversationHistory,
@@ -17,9 +20,6 @@ from graphrag.query.context_builder.conversation_history import (
 from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
 from graphrag.query.llm.text_utils import num_tokens
 from graphrag.query.structured_search.base import BaseSearch, SearchResult
-from graphrag.query.structured_search.local_search.system_prompt import (
-    LOCAL_SEARCH_SYSTEM_PROMPT,
-)

 DEFAULT_LLM_PARAMS = {
     "max_tokens": 1500,
@@ -37,7 +37,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
         llm: BaseLLM,
         context_builder: LocalContextBuilder,
         token_encoder: tiktoken.Encoding | None = None,
-        system_prompt: str = LOCAL_SEARCH_SYSTEM_PROMPT,
+        system_prompt: str | None = None,
         response_type: str = "multiple paragraphs",
         callbacks: list[BaseLLMCallback] | None = None,
         llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS,
@@ -50,7 +50,7 @@ class LocalSearch(BaseSearch[LocalContextBuilder]):
             llm_params=llm_params,
             context_builder_params=context_builder_params or {},
         )
-        self.system_prompt = system_prompt
+        self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT
         self.callbacks = callbacks
         self.response_type = response_type

@@ -28,6 +28,7 @@ from graphrag.config import (
     ClusterGraphConfigInput,
     CommunityReportsConfig,
     CommunityReportsConfigInput,
+    DRIFTSearchConfig,
     EmbedGraphConfig,
     EmbedGraphConfigInput,
     EntityExtractionConfig,
@@ -202,6 +203,7 @@ class TestDefaultConfig(unittest.TestCase):
         assert ClaimExtractionConfig is not None
         assert ClusterGraphConfig is not None
         assert CommunityReportsConfig is not None
+        assert DRIFTSearchConfig is not None
         assert EmbedGraphConfig is not None
         assert EntityExtractionConfig is not None
         assert GlobalSearchConfig is not None