capture generated tokens in token count

This commit is contained in:
Julian Whiting 2024-10-11 16:51:25 -04:00
parent d24e0bd3cc
commit 48afa46f24
2 changed files with 14 additions and 6 deletions

View File

@ -7,6 +7,8 @@ import json
import logging
from typing import Any
from graphrag.query.llm.text_utils import num_tokens
log = logging.getLogger(__name__)
@ -78,7 +80,13 @@ class DriftAction:
self.answer = response.pop("response", None)
self.score = response.pop("score", float("-inf"))
self.metadata.update({"context_data": search_result.context_data})
self.metadata.update({"token_ct": search_result.token_ct})
if self.answer is None:
log.warning("No answer found for query: %s", self.query)
generated_tokens = 0
else:
generated_tokens = num_tokens(self.answer, search_engine.token_encoder)
self.metadata.update({"token_ct": search_result.prompt_tokens + generated_tokens})
self.follow_ups = response.pop("follow_up_queries", [])
if not self.follow_ups:

View File

@ -129,11 +129,11 @@ class DRIFTSearch(BaseSearch[DRIFTSearchContextBuilder]):
])
follow_ups = [
fu
for i in response
for fu in i["follow_up_queries"]
if "follow_up_queries" in i
]
fu
for i in response
if "follow_up_queries" in i and i["follow_up_queries"]
for fu in i["follow_up_queries"]
]
if len(follow_ups) == 0:
error_msg = "No follow-up queries found in primer response. Ensure that the primer response includes follow-up queries."
raise RuntimeError(error_msg)