graphrag/graphrag/query/structured_search/basic_search/basic_context.py
Derek Worthen 2b70e4a4f3
Tokenizer (#2051)
* Add LiteLLM chat and embedding model providers.

* Fix code review findings.

* Add litellm.

* Fix formatting.

* Update dictionary.

* Update litellm.

* Fix embedding.

* Remove manual use of tiktoken and replace it with the Tokenizer interface, adding support for encoding and decoding the models supported by litellm (a minimal sketch of this interface follows the commit log below).

* Update litellm.

* Configure litellm to drop unsupported params.

* Cleanup semversioner release notes.

* Add num_tokens util to Tokenizer interface.

* Update litellm service factories.

* Cleanup litellm chat/embedding model argument assignment.

* Update chat and embedding type field for litellm use and future migration away from fnllm.

* Flatten litellm service organization.

* Update litellm.

* Update litellm factory validation.

* Flatten litellm rate limit service organization.

* Update rate limiter - disable with None/null instead of 0.

* Fix usage of get_tokenizer.

* Update litellm service registrations.

* Add jitter to exponential retry.

* Update validation.

* Update validation.

* Add litellm request logging layer.

* Update cache key.

* Update defaults.

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
2025-09-22 13:55:14 -06:00
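
The commit notes above refer to a Tokenizer interface with encode/decode support and a num_tokens utility, but the interface itself is not shown on this page. The following is a minimal sketch of what such an interface could look like, assuming the method names from the commit notes and a tiktoken-backed default; the class name TiktokenTokenizer and the encoding choice are illustrative assumptions, not graphrag's actual implementation.

from typing import Protocol

import tiktoken


class Tokenizer(Protocol):
    """Hypothetical shape of the tokenizer interface described in the commit notes."""

    def encode(self, text: str) -> list[int]:
        """Convert text to token ids."""
        ...

    def decode(self, tokens: list[int]) -> str:
        """Convert token ids back to text."""
        ...

    def num_tokens(self, text: str) -> int:
        """Count the tokens in a piece of text."""
        ...


class TiktokenTokenizer:
    """Illustrative tiktoken-backed tokenizer (assumed, not graphrag's own class)."""

    def __init__(self, encoding_name: str = "cl100k_base"):
        self._encoding = tiktoken.get_encoding(encoding_name)

    def encode(self, text: str) -> list[int]:
        return self._encoding.encode(text)

    def decode(self, tokens: list[int]) -> str:
        return self._encoding.decode(tokens)

    def num_tokens(self, text: str) -> int:
        return len(self.encode(text))

The BasicSearchContext in the file below only calls encode on its tokenizer, so any object with a compatible encode method can be supplied through the tokenizer argument.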

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Basic Context Builder implementation."""

import logging
from typing import cast

import pandas as pd

from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.context_builder.builders import (
    BasicContextBuilder,
    ContextBuilderResult,
)
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag.tokenizer.tokenizer import Tokenizer
from graphrag.vector_stores.base import BaseVectorStore

logger = logging.getLogger(__name__)


class BasicSearchContext(BasicContextBuilder):
    """Class representing the Basic Search Context Builder."""

    def __init__(
        self,
        text_embedder: EmbeddingModel,
        text_unit_embeddings: BaseVectorStore,
        text_units: list[TextUnit] | None = None,
        tokenizer: Tokenizer | None = None,
        embedding_vectorstore_key: str = "id",
    ):
        self.text_embedder = text_embedder
        self.tokenizer = tokenizer or get_tokenizer()
        self.text_units = text_units
        self.text_unit_embeddings = text_unit_embeddings
        self.embedding_vectorstore_key = embedding_vectorstore_key
        self.text_id_map = self._map_ids()

    def build_context(
        self,
        query: str,
        conversation_history: ConversationHistory | None = None,
        k: int = 10,
        max_context_tokens: int = 12_000,
        context_name: str = "Sources",
        column_delimiter: str = "|",
        text_id_col: str = "source_id",
        text_col: str = "text",
        **kwargs,
    ) -> ContextBuilderResult:
        """Build the context for the basic search mode."""
        if query != "":
            related_texts = self.text_unit_embeddings.similarity_search_by_text(
                text=query,
                text_embedder=lambda t: self.text_embedder.embed(t),
                k=k,
            )
            related_text_list = [
                {
                    text_id_col: self.text_id_map[f"{chunk.document.id}"],
                    text_col: chunk.document.text,
                }
                for chunk in related_texts
            ]
            related_text_df = pd.DataFrame(related_text_list)
        else:
            related_text_df = pd.DataFrame({
                text_id_col: [],
                text_col: [],
            })

        # add these related text chunks into context until we fill up the context window
        current_tokens = 0
        text_ids = []
        current_tokens = len(
            self.tokenizer.encode(text_id_col + column_delimiter + text_col + "\n")
        )
        for i, row in related_text_df.iterrows():
            text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
            tokens = len(self.tokenizer.encode(text))
            if current_tokens + tokens > max_context_tokens:
                msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
                logger.warning(msg)
                break
            current_tokens += tokens
            text_ids.append(i)

        final_text_df = cast(
            "pd.DataFrame",
            related_text_df[related_text_df.index.isin(text_ids)].reset_index(
                drop=True
            ),
        )
        final_text = final_text_df.to_csv(
            index=False, escapechar="\\", sep=column_delimiter
        )

        return ContextBuilderResult(
            context_chunks=final_text,
            context_records={context_name: final_text_df},
        )

    def _map_ids(self) -> dict[str, str]:
        """Map id to short id in the text units."""
        id_map = {}
        text_units = self.text_units or []
        for unit in text_units:
            id_map[unit.id] = unit.short_id
        return id_map
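
For reference, the following is a minimal, hypothetical sketch of how BasicSearchContext might be exercised end to end. The stub embedder, vector store, and tokenizer are stand-ins that implement only the methods this file actually calls (embed, similarity_search_by_text, encode); graphrag ships real implementations of each. The TextUnit keyword arguments are assumed from the fields this file reads (id, short_id) plus a text field.

from types import SimpleNamespace

from graphrag.data_model.text_unit import TextUnit
from graphrag.query.structured_search.basic_search.basic_context import (
    BasicSearchContext,
)


class StubEmbedder:
    """Stand-in for an EmbeddingModel; build_context only needs embed()."""

    def embed(self, text: str, **kwargs) -> list[float]:
        return [0.0, 1.0, 0.0]  # fixed vector; a real model returns learned embeddings


class StubVectorStore:
    """Stand-in for a BaseVectorStore; only similarity_search_by_text() is used."""

    def __init__(self, chunks: list[SimpleNamespace]):
        self._chunks = chunks

    def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs):
        text_embedder(text)  # mirror the real flow: embed the query first
        return self._chunks[:k]  # pretend these are the k nearest chunks


class StubTokenizer:
    """Stand-in tokenizer; build_context only calls encode() to count tokens."""

    def encode(self, text: str) -> list[int]:
        return [0] * len(text.split())  # one fake token id per word; only the length matters


# Text units known to the index; ids must line up with the vector-store results
# so that text_id_map can translate them to short ids.
units = [
    TextUnit(id="unit-1", short_id="1", text="GraphRAG builds a knowledge graph."),
    TextUnit(id="unit-2", short_id="2", text="Basic search retrieves raw text chunks."),
]
chunks = [
    SimpleNamespace(document=SimpleNamespace(id=u.id, text=u.text)) for u in units
]

context_builder = BasicSearchContext(
    text_embedder=StubEmbedder(),
    text_unit_embeddings=StubVectorStore(chunks),
    text_units=units,
    tokenizer=StubTokenizer(),
)
result = context_builder.build_context(query="what is basic search?", k=2)
print(result.context_chunks)  # pipe-delimited table with source_id and text columns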