Files
2025-01-19 07:52:50 +00:00

230 lines
7.4 KiB
Python

import json
import os
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from redis import asyncio as aioredis
# Database configuration.
# SECURITY: the fallback URLs below embed credentials in source control.
# Provide MONGODB_URL / REDIS_URL through the environment instead and rotate
# these credentials; the literals are kept only so deployments that rely on
# them keep working unchanged.
MONGODB_URL = os.getenv(
    "MONGODB_URL",
    "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper",
)
REDIS_URL = os.getenv(
    "REDIS_URL",
    "redis://:Obscura@2024@222.186.10.253:6379",
)

# Redis logical database indexes (overridable via environment).
# NOTE(review): a stock Redis server only has 16 databases (0-15); indexes
# 190+ require the server's `databases` setting to be raised — confirm config.
REDIS_REPORT_DB = int(os.getenv("REDIS_REPORT_DB", "190"))  # paper analysis reports
REDIS_CHAT_DB = int(os.getenv("REDIS_CHAT_DB", "191"))  # Q&A chat history
REDIS_TASK_DB = int(os.getenv("REDIS_TASK_DB", "192"))  # task status
class Database:
    """Holder for the process-wide MongoDB client.

    ``client`` is ``None`` until ``connect_to_mongo()`` assigns a connected
    ``AsyncIOMotorClient``.
    """

    client: Optional[AsyncIOMotorClient] = None


# Module-level singleton shared by the MongoDB connection helpers.
db = Database()
async def get_database() -> AsyncIOMotorDatabase:
    """Return the ``paper`` database handle from the shared MongoDB client.

    Raises:
        RuntimeError: if no client has been created yet (``connect_to_mongo()``
            has not completed successfully), instead of an opaque ``TypeError``
            from subscripting ``None``.
    """
    if db.client is None:
        raise RuntimeError(
            "MongoDB client is not initialised; await connect_to_mongo() first"
        )
    return db.client["paper"]
async def connect_to_mongo():
    """Create the shared MongoDB client and verify it with a ``ping``.

    The client is published to ``db.client`` only after the ping succeeds.
    On any failure the newly created client (if one was constructed) is
    closed, a message is printed, and the exception is re-raised.

    Raises:
        Exception: whatever client construction or the ``ping`` command raised.
    """
    client = None
    try:
        client = AsyncIOMotorClient(MONGODB_URL)
        await client.admin.command('ping')
    except Exception as e:
        print(f"Could not connect to MongoDB: {e}")
        # Don't leak the half-initialised client or leave it published.
        if client is not None:
            client.close()
        raise
    db.client = client
    print("Successfully connected to MongoDB")
async def close_mongo_connection():
    """Close the shared MongoDB client, if any, and clear the reference.

    Safe to call when no client was ever created or after it has already
    been closed by this function.
    """
    if db.client is None:
        return
    db.client.close()
    db.client = None
async def get_redis(*, db_index: Optional[int] = None):
    """Create a new async Redis client for ``REDIS_URL``.

    Responses are decoded to ``str`` using UTF-8. Each call returns a fresh
    client; the caller is responsible for ``await client.aclose()``.

    Args:
        db_index: Optional logical database to bind the client's connections
            to. When omitted, the behaviour is unchanged: the database named
            in ``REDIS_URL`` is used, or 0 if the URL names none. Binding here
            is more robust than issuing ``SELECT`` afterwards, because a
            connection that reconnects re-selects its configured database.
    """
    options: Dict[str, Any] = {"encoding": "utf-8", "decode_responses": True}
    if db_index is not None:
        options["db"] = db_index
    return aioredis.from_url(REDIS_URL, **options)
@asynccontextmanager
async def _redis_session(db_index: int):
    """Yield a Redis client bound to logical database ``db_index``; always close it.

    The index is passed at connection time rather than via a later ``SELECT``,
    so a connection that reconnects mid-call stays on the same database.
    Connection options (UTF-8, decoded responses) match ``get_redis()``.
    """
    client = aioredis.from_url(
        REDIS_URL,
        encoding="utf-8",
        decode_responses=True,
        db=db_index,
    )
    try:
        yield client
    finally:
        await client.aclose()


class DatabaseOperations:
    """Static data-access helpers.

    Paper documents live in the MongoDB ``papers`` collection. Analysis
    reports, Q&A history and task status live in separate Redis logical
    databases (``REDIS_REPORT_DB``, ``REDIS_CHAT_DB``, ``REDIS_TASK_DB``).
    Every Redis-backed method uses a short-lived client that is closed on
    all paths, including early returns and exceptions.
    """

    @staticmethod
    async def insert_paper(paper: dict):
        """Insert a paper document and return the driver's insert result."""
        mongo = await get_database()
        return await mongo.papers.insert_one(paper)

    @staticmethod
    async def set_paper_processing_status(file_hash: str):
        """Overwrite the paper's report entry with a ``processing`` status."""
        initial_status = {
            "status": "processing",
            "message": "Analysis in progress",
        }
        async with _redis_session(REDIS_REPORT_DB) as redis:
            await redis.set(f"paper_report:{file_hash}", json.dumps(initial_status))

    @staticmethod
    async def save_paper_report(file_hash: str, report_data: dict):
        """Store ``report_data`` as JSON under the paper's report key."""
        async with _redis_session(REDIS_REPORT_DB) as redis:
            await redis.set(f"paper_report:{file_hash}", json.dumps(report_data))

    @staticmethod
    async def get_paper_report(file_hash: str) -> Optional[Any]:
        """Return the decoded report JSON for ``file_hash``, or ``None`` if absent."""
        async with _redis_session(REDIS_REPORT_DB) as redis:
            report_data = await redis.get(f"paper_report:{file_hash}")
        return json.loads(report_data) if report_data else None

    @staticmethod
    async def save_qa_history(file_hash: str, question: str, answer: str):
        """Append a question/answer pair (with a UTC ISO timestamp) to the chat history.

        NOTE(review): this is a non-atomic read-modify-write of one JSON string;
        two concurrent calls for the same ``file_hash`` can drop an entry.
        Consider WATCH/MULTI if concurrent writes are possible.
        """
        chat_history_key = f"chat_history:{file_hash}"
        async with _redis_session(REDIS_CHAT_DB) as redis:
            existing_history = await redis.get(chat_history_key)
            history = json.loads(existing_history) if existing_history else []
            history.append({
                "question": question,
                "answer": answer,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            await redis.set(chat_history_key, json.dumps(history))

    @staticmethod
    async def get_qa_history(file_hash: str) -> List[Any]:
        """Return the stored Q&A history for ``file_hash`` (empty list if none)."""
        async with _redis_session(REDIS_CHAT_DB) as redis:
            history_data = await redis.get(f"chat_history:{file_hash}")
        return json.loads(history_data) if history_data else []

    @staticmethod
    async def save_task_status(task_id: str, status_data: dict):
        """Write ``status_data`` fields into the ``task:<task_id>`` Redis hash."""
        async with _redis_session(REDIS_TASK_DB) as redis:
            await redis.hset(f"task:{task_id}", mapping=status_data)

    @staticmethod
    async def get_task_status(task_id: str) -> Optional[Dict[str, str]]:
        """Return all fields of the task's hash, or ``None`` if it is empty/missing."""
        async with _redis_session(REDIS_TASK_DB) as redis:
            task_data = await redis.hgetall(f"task:{task_id}")
        return task_data if task_data else None

    @staticmethod
    async def delete_paper(paper_id: str) -> Optional[Dict[str, str]]:
        """Delete a paper document and its Redis report and chat-history keys.

        ``paper_id`` is matched against the document's ``file_hash`` field.

        Returns:
            ``{"message": "paper successfully deleted"}`` on success, or
            ``None`` when no paper matches.

        Raises:
            Exception: wrapping (and chained from) any error raised while
                looking up or deleting the data.
        """
        mongo = await get_database()
        try:
            paper = await mongo.papers.find_one({"file_hash": paper_id})
            if not paper:
                # Nothing to delete; no Redis client is opened on this path.
                return None
            file_hash = paper.get("file_hash")
            await mongo.papers.delete_one({"file_hash": paper_id})
            async with _redis_session(REDIS_REPORT_DB) as redis:
                await redis.delete(f"paper_report:{file_hash}")
            async with _redis_session(REDIS_CHAT_DB) as redis:
                await redis.delete(f"chat_history:{paper_id}")
            return {"message": "paper successfully deleted"}
        except Exception as e:
            raise Exception(f"Failed to delete paper: {str(e)}") from e

    @staticmethod
    async def get_papers() -> List[Dict[str, Any]]:
        """Return a summary (id, title, upload time, file hash) of every paper."""
        mongo = await get_database()
        return [
            {
                "_id": str(ref["_id"]),
                "paper_title": ref["paper_title"],
                "upload_time": ref["upload_time"],
                "file_hash": ref.get("file_hash"),
            }
            async for ref in mongo.papers.find()
        ]

    @staticmethod
    async def get_paper(file_hash: str):
        """Return the paper document whose ``file_hash`` matches, or ``None``."""
        mongo = await get_database()
        return await mongo.papers.find_one({"file_hash": file_hash})

    @staticmethod
    async def check_existing_report(file_hash: str) -> bool:
        """Return True only if a stored report exists with ``status == "completed"``."""
        async with _redis_session(REDIS_REPORT_DB) as redis:
            existing_report = await redis.get(f"paper_report:{file_hash}")
        if not existing_report:
            return False
        return json.loads(existing_report).get("status") == "completed"

    @staticmethod
    async def save_analysis_error(file_hash: str, error_message: str):
        """Overwrite the paper's report entry with a ``failed`` status and message."""
        error_status = {
            "status": "failed",
            "message": str(error_message),
        }
        async with _redis_session(REDIS_REPORT_DB) as redis:
            await redis.set(f"paper_report:{file_hash}", json.dumps(error_status))