diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f8f85d141a..f40d2d4bd2 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -63,6 +63,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.enums import CreatorUserRole from services.external_knowledge_service import ExternalDatasetService default_retrieval_model: dict[str, Any] = { @@ -176,13 +177,17 @@ class DatasetRetrieval: ) all_documents = [] - user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + creator_user_role = ( + CreatorUserRole.ACCOUNT + if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER + ) if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( app_id, tenant_id, user_id, - user_from, + creator_user_role, query, available_datasets, model_instance, @@ -197,7 +202,7 @@ class DatasetRetrieval: app_id, tenant_id, user_id, - user_from, + creator_user_role, available_datasets, query, retrieve_config.top_k or 0, @@ -334,7 +339,7 @@ class DatasetRetrieval: app_id: str, tenant_id: str, user_id: str, - user_from: str, + creator_user_role: CreatorUserRole, query: str, available_datasets: list, model_instance: ModelInstance, @@ -444,7 +449,7 @@ class DatasetRetrieval: weights=retrieval_model_config.get("weights", None), document_ids_filter=document_ids_filter, ) - self._on_query(query, None, [dataset_id], app_id, user_from, user_id) + self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id) if results: thread = threading.Thread( @@ -466,7 +471,7 @@ class DatasetRetrieval: app_id: str, tenant_id: str, user_id: str, - user_from: str, + creator_user_role: CreatorUserRole, available_datasets: list, query: str | None, top_k: int, @@ -584,7 +589,7 @@ class DatasetRetrieval: if thread_exceptions: raise thread_exceptions[0] - self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) + self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id) if all_documents: # add thread to call _on_retrieval_end @@ -733,7 +738,7 @@ class DatasetRetrieval: attachment_ids: list[str] | None, dataset_ids: list[str], app_id: str, - user_from: str, + creator_user_role: CreatorUserRole, user_id: str, ): """ @@ -755,7 +760,7 @@ class DatasetRetrieval: content=json.dumps(contents), source="app", source_app_id=app_id, - created_by_role=user_from, + created_by_role=creator_user_role, created_by=user_id, ) dataset_queries.append(dataset_query)