Some algorithmic refactors

Alonso Guevara 2024-10-14 19:33:55 -06:00
parent 60441985f2
commit 44246e2fd9
6 changed files with 124 additions and 71 deletions

View File

@@ -62,3 +62,30 @@ class CommunityReport(Named):
full_content_embedding=d.get(full_content_embedding_key),
attributes=d.get(attributes_key),
)
def to_dict(
self,
id_key: str = "id",
title_key: str = "title",
community_id_key: str = "community_id",
short_id_key: str = "short_id",
summary_key: str = "summary",
full_content_key: str = "full_content",
rank_key: str = "rank",
summary_embedding_key: str = "summary_embedding",
full_content_embedding_key: str = "full_content_embedding",
attributes_key: str = "attributes",
) -> dict[str, Any]:
"""Convert the community report to a dictionary."""
return {
id_key: self.id,
title_key: self.title,
community_id_key: self.community_id,
short_id_key: self.short_id,
summary_key: self.summary,
full_content_key: self.full_content,
rank_key: self.rank,
summary_embedding_key: self.summary_embedding,
full_content_embedding_key: self.full_content_embedding,
attributes_key: self.attributes,
}
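
As a point of reference, here is a minimal sketch of the key-override pattern this to_dict follows. The SimpleReport class below is hypothetical and only illustrates the idea; it is not part of graphrag:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class SimpleReport:
        # Hypothetical stand-in for CommunityReport, only to show the pattern
        id: str
        rank: float | None = None

        def to_dict(self, id_key: str = "id", rank_key: str = "rank") -> dict[str, Any]:
            # Callers can rename output keys without post-processing the dict
            return {id_key: self.id, rank_key: self.rank}

    report = SimpleReport(id="r1", rank=7.5)
    print(report.to_dict())                       # {'id': 'r1', 'rank': 7.5}
    print(report.to_dict(rank_key="importance"))  # {'id': 'r1', 'importance': 7.5}

Unlike reading report.__dict__, this keeps the serialized schema explicit, which is what the DataFrame construction in the context builder below relies on.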

View File

@@ -146,15 +146,21 @@ class DriftAction:
DriftAction
A deserialized instance of DriftAction.
"""
action = cls(data["query"])
# Ensure 'query' exists in the data, raise a ValueError if missing
query = data.get("query")
if query is None:
error_message = "Missing 'query' key in serialized data"
raise ValueError(error_message)
# Initialize the DriftAction
action = cls(query)
action.answer = data.get("answer")
action.score = data.get("score")
action.metadata = data.get("metadata", {})
action.follow_ups = (
[cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])]
if "follow_ups" in data
else []
)
action.follow_ups = [
cls.deserialize(fu_data) for fu_data in data.get("follow_up_queries", [])
]
return action
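
The reshaped deserialize boils down to: validate the required key, then hydrate optional fields with .get defaults. A self-contained sketch of that pattern using plain dicts rather than the actual DriftAction API:

    from typing import Any

    def deserialize_action(data: dict[str, Any]) -> dict[str, Any]:
        # Required key: fail loudly instead of surfacing an opaque KeyError
        query = data.get("query")
        if query is None:
            error_message = "Missing 'query' key in serialized data"
            raise ValueError(error_message)
        # Optional keys fall back to sensible defaults
        return {
            "query": query,
            "answer": data.get("answer"),
            "metadata": data.get("metadata", {}),
            "follow_ups": [
                deserialize_action(fu) for fu in data.get("follow_up_queries", [])
            ],
        }

    print(deserialize_action({"query": "q0", "follow_up_queries": [{"query": "q1"}]}))
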
@classmethod
@@ -187,10 +193,22 @@ class DriftAction:
action.score = response.get("score")
return action
error_message = "Response must be a dictionary"
# If response is a string, attempt to parse as JSON
if isinstance(response, str):
try:
parsed_response = json.loads(response)
if isinstance(parsed_response, dict):
return cls.from_primer_response(query, parsed_response)
error_message = "Parsed response must be a dictionary."
raise ValueError(error_message)
except json.JSONDecodeError as e:
error_message = f"Failed to parse response string: {e}. Parsed response must be a dictionary."
raise ValueError(error_message)
error_message = f"Unsupported response type: {type(response).__name__}. Expected a dictionary or JSON string."
raise ValueError(error_message)
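
The response handling added to from_primer_response amounts to a dict-or-JSON-string dispatch. A standalone sketch of that dispatch; the helper name is hypothetical:

    import json
    from typing import Any

    def normalize_primer_response(response: Any) -> dict[str, Any]:
        # Already a dict: pass through unchanged
        if isinstance(response, dict):
            return response
        # A string: try to parse it as JSON and require a dict payload
        if isinstance(response, str):
            try:
                parsed = json.loads(response)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse response string: {e}") from e
            if isinstance(parsed, dict):
                return parsed
            raise ValueError("Parsed response must be a dictionary.")
        raise ValueError(
            f"Unsupported response type: {type(response).__name__}. "
            "Expected a dictionary or JSON string."
        )

    print(normalize_primer_response('{"intermediate_answer": "...", "score": 0.9}'))
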
def __hash__(self):
def __hash__(self) -> int:
"""
Allow DriftAction objects to be hashable for use in networkx.MultiDiGraph.
@@ -203,7 +221,7 @@ class DriftAction:
"""
return hash(self.query)
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""
Check equality based on the query string.
@@ -215,4 +233,6 @@ class DriftAction:
bool
True if the other object is a DriftAction with the same query, False otherwise.
"""
return isinstance(other, DriftAction) and self.query == other.query
if not isinstance(other, DriftAction):
return False
return self.query == other.query

View File

@@ -111,12 +111,16 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
------
ValueError: If some reports are missing full content or full content embeddings.
"""
report_df = pd.DataFrame([report.__dict__ for report in reports])
report_df = pd.DataFrame([report.to_dict() for report in reports])
missing_content_error = "Some reports are missing full content."
missing_embedding_error = "Some reports are missing full content embeddings."
if report_df["full_content"].isna().sum() > 0:
if (
"full_content" not in report_df.columns
or report_df["full_content"].isna().sum() > 0
):
raise ValueError(missing_content_error)
if (
"full_content_embedding" not in report_df.columns
or report_df["full_content_embedding"].isna().sum() > 0
@@ -141,7 +145,9 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
bool: True if embeddings match, otherwise False.
"""
return (
isinstance(query_embedding, type(embedding))
query_embedding is not None
and embedding is not None
and isinstance(query_embedding, type(embedding))
and len(query_embedding) == len(embedding)
and isinstance(query_embedding[0], type(embedding[0]))
)
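
The None guards matter because when both inputs are None, the bare isinstance(query_embedding, type(embedding)) check passes and the expression then fails at len(None). A minimal sketch of the check as a free function, with illustrative inputs in place of real embeddings:

    import numpy as np

    def embeddings_compatible(query_embedding, embedding) -> bool:
        # Mirror of the check: reject None, then require the same container type,
        # the same length, and the same element type
        return (
            query_embedding is not None
            and embedding is not None
            and isinstance(query_embedding, type(embedding))
            and len(query_embedding) == len(embedding)
            and isinstance(query_embedding[0], type(embedding[0]))
        )

    q = np.random.rand(384)
    doc = np.random.rand(384)
    print(embeddings_compatible(q, doc))         # True
    print(embeddings_compatible(None, None))     # False (the bare check hits len(None) here)
    print(embeddings_compatible(list(q), doc))   # False: list vs ndarray
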
@@ -182,21 +188,27 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
report_df = self.convert_reports_to_df(self.reports)
if self.check_query_doc_encodings(
# Check compatibility between query embedding and document embeddings
if not self.check_query_doc_encodings(
query_embedding, report_df["full_content_embedding"].iloc[0]
):
report_df["similarity"] = report_df["full_content_embedding"].apply(
lambda x: np.dot(x, query_embedding)
/ (np.linalg.norm(x) * np.linalg.norm(query_embedding))
)
top_k = report_df.sort_values("similarity", ascending=False).head(
self.config.drift_k_followups
)
else:
incompatible_embeddings_error = (
error_message = (
"Query and document embeddings are not compatible. "
"Please ensure that the embeddings are of the same type and length."
)
raise ValueError(incompatible_embeddings_error)
raise ValueError(error_message)
# Vectorized cosine similarity computation
query_norm = np.linalg.norm(query_embedding)
document_norms = np.linalg.norm(
report_df["full_content_embedding"].to_list(), axis=1
)
dot_products = np.dot(
np.vstack(report_df["full_content_embedding"].to_list()), query_embedding
)
report_df["similarity"] = dot_products / (document_norms * query_norm)
# Sort by similarity and select top-k
top_k = report_df.nlargest(self.config.drift_k_followups, "similarity")
return top_k.loc[:, ["short_id", "community_id", "full_content"]]
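
In isolation, the vectorized cosine-similarity step works like this; random vectors stand in for real report embeddings, and 2 stands in for drift_k_followups:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    report_df = pd.DataFrame({
        "short_id": [f"r{i}" for i in range(5)],
        "full_content_embedding": [rng.random(8) for _ in range(5)],
    })
    query_embedding = rng.random(8)

    # Stack the per-report vectors into one (n_reports, dim) matrix
    doc_matrix = np.vstack(report_df["full_content_embedding"].to_list())
    dot_products = doc_matrix @ query_embedding
    document_norms = np.linalg.norm(doc_matrix, axis=1)
    query_norm = np.linalg.norm(query_embedding)

    report_df["similarity"] = dot_products / (document_norms * query_norm)
    top_k = report_df.nlargest(2, "similarity")
    print(top_k[["short_id", "similarity"]])

Per the pandas docs, nlargest(n, column) is equivalent to sort_values(column, ascending=False).head(n) but more performant, since it avoids fully sorting the frame.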

View File

@@ -8,6 +8,7 @@ import logging
import secrets
import time
import numpy as np
import pandas as pd
import tiktoken
from tqdm.asyncio import tqdm_asyncio
@@ -159,19 +160,19 @@ class DRIFTPrimer:
-------
SearchResult: The search result containing the response and context data.
"""
start_time = time.time()
start_time = time.perf_counter()
report_folds = self.split_reports(top_k_reports)
tasks = [self.decompose_query(query, fold) for fold in report_folds]
results_with_tokens = await tqdm_asyncio.gather(*tasks)
completion_time = time.time() - start_time
completion_time = time.perf_counter() - start_time
return SearchResult(
response=[response for response, _ in results_with_tokens],
context_data={"top_k_reports": top_k_reports},
context_text=str(top_k_reports),
context_text=top_k_reports.to_json(),
completion_time=completion_time,
llm_calls=2,
llm_calls=len(results_with_tokens),
prompt_tokens=sum(tokens for _, tokens in results_with_tokens),
)
@@ -186,17 +187,7 @@ class DRIFTPrimer:
-------
list[pd.DataFrame]: List of report folds.
"""
folds = []
num_reports = len(reports)
primer_folds = self.config.primer_folds or 1 # Ensure at least one fold
for i in range(primer_folds):
start_idx = i * num_reports // primer_folds
end_idx = (
num_reports
if i == primer_folds - 1
else (i + 1) * num_reports // primer_folds
)
fold = reports.iloc[start_idx:end_idx]
folds.append(fold)
return folds
if primer_folds == 1:
return [reports]
return [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)]
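
np.array_split handles the uneven-remainder bookkeeping that the manual index arithmetic did before. A small standalone sketch, with a toy DataFrame in place of top_k_reports and a literal fold count in place of config.primer_folds:

    import numpy as np
    import pandas as pd

    reports = pd.DataFrame({
        "community_id": range(10),
        "full_content": [f"report {i}" for i in range(10)],
    })
    primer_folds = 3

    if primer_folds == 1:
        folds = [reports]
    else:
        # array_split tolerates sizes that don't divide evenly (here 4/3/3 rows)
        folds = [pd.DataFrame(fold) for fold in np.array_split(reports, primer_folds)]

    print([len(fold) for fold in folds])  # [4, 3, 3]

Depending on the pandas version, np.array_split on a DataFrame may emit a deprecation warning; wrapping each fold in pd.DataFrame keeps the return type explicit either way.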

View File

@@ -124,13 +124,9 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "No intermediate answers found in primer response. Ensure that the primer response includes intermediate answers."
raise RuntimeError(error_msg)
intermediate_answer = "\n\n".join(
[
i["intermediate_answer"]
for i in response
if "intermediate_answer" in i
]
)
intermediate_answer = "\n\n".join([
i["intermediate_answer"] for i in response if "intermediate_answer" in i
])
follow_ups = [fu for i in response for fu in i.get("follow_up_queries", [])]
if len(follow_ups) == 0:
@@ -193,7 +189,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
error_msg = "DRIFT Search query cannot be empty."
raise ValueError(error_msg)
start_time = time.time()
start_time = time.perf_counter()
primer_token_ct = 0
context_token_ct = 0
@@ -233,7 +229,7 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
self.query_state.add_all_follow_ups(action, action.follow_ups)
epochs += 1
t_elapsed = time.time() - start_time
t_elapsed = time.perf_counter() - start_time
# Calculate token usage
total_tokens = (

View File

@@ -5,7 +5,11 @@
import logging
import random
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from typing import Callable
import networkx as nx
@@ -55,7 +59,9 @@ class QueryState:
"""Find all unanswered actions in the graph."""
return [node for node in self.graph.nodes if not node.is_complete]
def rank_incomplete_actions(self, scorer: Any | None = None) -> list[DriftAction]:
def rank_incomplete_actions(
self, scorer: Callable[[DriftAction], float] | None = None
) -> list[DriftAction]:
"""Rank all unanswered actions in the graph if scorer available."""
unanswered = self.find_incomplete_actions()
if scorer:
@@ -63,9 +69,9 @@ class QueryState:
node.compute_score(scorer)
return sorted(
unanswered,
key=lambda node: node.score
if node.score is not None
else float("-inf"),
key=lambda node: (
node.score if node.score is not None else float("-inf")
),
reverse=True,
)
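
The None-to-negative-infinity key keeps unscored actions sortable alongside scored ones. In isolation the sort behaves like this; plain tuples stand in for DriftAction nodes:

    unanswered = [("a", 0.4), ("b", None), ("c", 0.9)]

    ranked = sorted(
        unanswered,
        # Unscored entries sink to the bottom instead of raising a TypeError
        key=lambda node: node[1] if node[1] is not None else float("-inf"),
        reverse=True,
    )
    print(ranked)  # [('c', 0.9), ('a', 0.4), ('b', None)]
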
@@ -81,30 +87,31 @@ class QueryState:
node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}
# Serialize nodes
nodes = []
for node in self.graph.nodes():
node_data = node.serialize(include_follow_ups=False)
node_data["id"] = node_to_id[node]
node_attributes = self.graph.nodes[node]
if node_attributes:
node_data.update(node_attributes)
nodes.append(node_data)
nodes: list[dict[str, Any]] = [
{
**node.serialize(include_follow_ups=False),
"id": node_to_id[node],
**self.graph.nodes[node],
}
for node in self.graph.nodes()
]
# Serialize edges
edges = []
for u, v, edge_data in self.graph.edges(data=True):
edge_info = {
edges: list[dict[str, Any]] = [
{
"source": node_to_id[u],
"target": node_to_id[v],
"weight": edge_data.get("weight", 1.0),
}
edges.append(edge_info)
for u, v, edge_data in self.graph.edges(data=True)
]
if include_context:
context_data = {}
for node in nodes:
if node["metadata"].get("context_data") and node.get("query"):
context_data[node["query"]] = node["metadata"]["context_data"]
context_data = {
node["query"]: node["metadata"]["context_data"]
for node in nodes
if node["metadata"].get("context_data") and node.get("query")
}
context_text = str(context_data)