diff --git a/__pycache__/api.cpython-311.pyc b/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000..02d78b6 Binary files /dev/null and b/__pycache__/api.cpython-311.pyc differ diff --git a/__pycache__/database.cpython-311.pyc b/__pycache__/database.cpython-311.pyc new file mode 100644 index 0000000..45936d5 Binary files /dev/null and b/__pycache__/database.cpython-311.pyc differ diff --git a/api.py b/api.py new file mode 100644 index 0000000..4060c64 --- /dev/null +++ b/api.py @@ -0,0 +1,41 @@ +from openai import OpenAI +import aiohttp +import json + +# DeepSeek API Configuration +DEEPSEEK_API_KEY = "sk-3027fb3c810b4e17985fa397d41250b9" +DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1" + +client = OpenAI( + base_url=DEEPSEEK_BASE_URL, + api_key=DEEPSEEK_API_KEY +) + +async def get_aiohttp_session(): + """获取异步HTTP会话""" + return aiohttp.ClientSession( + base_url=f"{DEEPSEEK_BASE_URL}/", + headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"} + ) + +async def call_deepseek(messages: list) -> dict: + """异步调用DeepSeek API""" + async with await get_aiohttp_session() as session: + async with session.post("/chat/completions", json={ + "model": "deepseek-chat", + "messages": messages, + "response_format": {"type": "json_object"} + }) as response: + if response.status == 200: + data = await response.json() + return json.loads(data["choices"][0]["message"]["content"]) + else: + raise Exception(f"API调用失败: {await response.text()}") + +async def call_deepseek_sync(messages: list) -> str: + """同步调用DeepSeek API""" + response = client.chat.completions.create( + model="deepseek-chat", + messages=messages + ) + return response.choices[0].message.content \ No newline at end of file diff --git a/database.py b/database.py new file mode 100644 index 0000000..7b7784a --- /dev/null +++ b/database.py @@ -0,0 +1,230 @@ +from motor.motor_asyncio import AsyncIOMotorClient +from redis import asyncio as aioredis +from datetime import datetime, timezone +from bson import ObjectId +import json +from typing import List, Dict, Any, Optional + +# 数据库配置 +MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper" +REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379" + +# Redis 数据库索引 +REDIS_REPORT_DB = 190 # 文献分析报告 +REDIS_CHAT_DB = 191 # 聊天历史 +REDIS_TASK_DB = 192 # 任务状态 + +class Database: + """数据库连接管理类""" + client: AsyncIOMotorClient = None + +db = Database() + +async def get_database() -> AsyncIOMotorClient: + return db.client["paper"] + +async def connect_to_mongo(): + """连接到MongoDB""" + try: + db.client = AsyncIOMotorClient(MONGODB_URL) + await db.client.admin.command('ping') + print("Successfully connected to MongoDB") + except Exception as e: + print(f"Could not connect to MongoDB: {e}") + raise + +async def close_mongo_connection(): + """关闭MongoDB连接""" + db.client.close() + +async def get_redis(): + """获取Redis连接""" + redis = aioredis.from_url( + REDIS_URL, + encoding="utf-8", + decode_responses=True, + ) + return redis + +class DatabaseOperations: + """数据库操作类""" + + @staticmethod + async def insert_paper(paper: dict): + """插入新文献""" + db = await get_database() + result = await db.papers.insert_one(paper) + return result + + @staticmethod + async def set_paper_processing_status(file_hash: str): + """设置文献处理状态""" + redis = await get_redis() + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + initial_status = { + "status": "processing", + "message": "Analysis in progress" + } + await redis.set(report_key, json.dumps(initial_status)) + finally: + await redis.aclose() + + @staticmethod + async def save_paper_report(file_hash: str, report_data: dict): + """保存文献分析报告""" + redis = await get_redis() + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + await redis.set(report_key, json.dumps(report_data)) + finally: + await redis.aclose() + + @staticmethod + async def get_paper_report(file_hash: str): + """获取文献分析报告""" + redis = await get_redis() + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + report_data = await redis.get(report_key) + return json.loads(report_data) if report_data else None + finally: + await redis.aclose() + + @staticmethod + async def save_qa_history(file_hash: str, question: str, answer: str): + """保存问答历史""" + redis = await get_redis() + try: + await redis.select(REDIS_CHAT_DB) + chat_history_key = f"chat_history:{file_hash}" + + existing_history = await redis.get(chat_history_key) + history = json.loads(existing_history) if existing_history else [] + + history.append({ + "question": question, + "answer": answer, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + await redis.set(chat_history_key, json.dumps(history)) + finally: + await redis.aclose() + + @staticmethod + async def get_qa_history(file_hash: str): + """获取问答历史""" + redis = await get_redis() + try: + await redis.select(REDIS_CHAT_DB) + chat_history_key = f"chat_history:{file_hash}" + history_data = await redis.get(chat_history_key) + return json.loads(history_data) if history_data else [] + finally: + await redis.aclose() + + @staticmethod + async def save_task_status(task_id: str, status_data: dict): + """保存任务状态""" + redis = await get_redis() + try: + await redis.select(REDIS_TASK_DB) + await redis.hset(f"task:{task_id}", mapping=status_data) + finally: + await redis.aclose() + + @staticmethod + async def get_task_status(task_id: str): + """获取任务状态""" + redis = await get_redis() + try: + await redis.select(REDIS_TASK_DB) + task_data = await redis.hgetall(f"task:{task_id}") + return task_data if task_data else None + finally: + await redis.aclose() + + @staticmethod + async def delete_paper(paper_id: str): + """删除文献""" + db = await get_database() + redis = await get_redis() + + try: + paper = await db.papers.find_one({"file_hash": paper_id}) + if not paper: + return None + + file_hash = paper.get("file_hash") + await db.papers.delete_one({"file_hash": paper_id}) + + # 删除Redis中的数据 + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + await redis.delete(report_key) + + await redis.select(REDIS_CHAT_DB) + chat_history_key = f"chat_history:{paper_id}" + await redis.delete(chat_history_key) + finally: + await redis.aclose() + + return {"message": "paper successfully deleted"} + except Exception as e: + raise Exception(f"Failed to delete paper: {str(e)}") + + @staticmethod + async def get_papers(): + """获取所有文献列表""" + db = await get_database() + papers = [] + async for ref in db.papers.find(): + papers.append({ + "_id": str(ref["_id"]), + "paper_title": ref["paper_title"], + "upload_time": ref["upload_time"], + "file_hash": ref.get("file_hash") + }) + return papers + + @staticmethod + async def get_paper(file_hash: str): + """获取单个文献信息""" + db = await get_database() + return await db.papers.find_one({"file_hash": file_hash}) + + @staticmethod + async def check_existing_report(file_hash: str): + """检查是否已存在分析报告""" + redis = await get_redis() + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + existing_report = await redis.get(report_key) + if existing_report: + report_data = json.loads(existing_report) + if report_data.get("status") == "completed": + return True + return False + finally: + await redis.aclose() + + @staticmethod + async def save_analysis_error(file_hash: str, error_message: str): + """保存分析错误状态""" + redis = await get_redis() + try: + await redis.select(REDIS_REPORT_DB) + report_key = f"paper_report:{file_hash}" + error_status = { + "status": "failed", + "message": str(error_message) + } + await redis.set(report_key, json.dumps(error_status)) + finally: + await redis.aclose() \ No newline at end of file diff --git a/paper.py b/paper.py index 105bcd7..198881b 100644 --- a/paper.py +++ b/paper.py @@ -1,59 +1,23 @@ -from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query +from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from typing import Dict, List, Optional +from typing import List, Optional from datetime import datetime, timezone import asyncio -from motor.motor_asyncio import AsyncIOMotorClient from bson import ObjectId from pydantic import BaseModel, Field -from redis import asyncio as aioredis -from typing import Any -from contextlib import asynccontextmanager import json -import os -import PyPDF2 import aiohttp from concurrent.futures import ThreadPoolExecutor -from openai import OpenAI -import hashlib +from contextlib import asynccontextmanager -# Database Configuration -MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper" -REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379" +from database import ( + connect_to_mongo, close_mongo_connection, + DatabaseOperations as db_ops +) +from api import call_deepseek, call_deepseek_sync - -# Database connection -class Database: - """数据库连接管理类""" - client: AsyncIOMotorClient = None - -db = Database() - -async def get_database() -> AsyncIOMotorClient: - return db.client["paper"] - -async def connect_to_mongo(): - try: - db.client = AsyncIOMotorClient(MONGODB_URL) - # 验证连接 - await db.client.admin.command('ping') - print("Successfully connected to MongoDB") - except Exception as e: - print(f"Could not connect to MongoDB: {e}") - raise - -async def close_mongo_connection(): - db.client.close() - -# Redis setup -async def get_redis(): - redis = aioredis.from_url( - REDIS_URL, - encoding="utf-8", - decode_responses=True, - ) - return redis +# 在全局范围创建线程池 +pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量 @asynccontextmanager async def lifespan(app: FastAPI): @@ -76,7 +40,6 @@ app.add_middleware( expose_headers=["*"] ) - # 在其他 Pydantic 模型后添加 class paperModel(BaseModel): """文献引用模型""" @@ -89,101 +52,36 @@ class paperModel(BaseModel): arbitrary_types_allowed = True json_encoders = {ObjectId: str} - # 删除文献 @app.delete("/paper/delete/{paper_id}") -async def delete_paper( - paper_id: str -): +async def delete_paper(paper_id: str): """删除文献""" - db = await get_database() - try: - # 获取文献信息 - paper = await db.papers.find_one({"file_hash": paper_id}) - if not paper: + result = await db_ops.delete_paper(paper_id) + if not result: raise HTTPException(status_code=404, detail="paper not found") - - # 从文件存储服务删除文件 - file_hash = paper.get("file_hash") - # 删除数据库记录 - await db.papers.delete_one({"file_hash": paper_id}) - - # 删除Redis中的分析报告(如果存在) - try: - redis = await get_redis() - await redis.select(190) - report_key = f"paper_report:{file_hash}" - await redis.delete(report_key) - - # 删除聊天历史(如果存在) - await redis.select(191) - chat_history_key = f"chat_history:{paper_id}" - await redis.delete(chat_history_key) - except Exception as redis_error: - print(f"Warning: Failed to delete Redis keys: {str(redis_error)}") - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") - - return {"message": "paper successfully deleted"} - + return result except Exception as e: - print(f"Error deleting paper: {str(e)}") - raise HTTPException(status_code=500, detail=f"Failed to delete paper: {str(e)}") - + raise HTTPException(status_code=500, detail=str(e)) + @app.get("/paper/papers") async def get_papers(): """获取文献列表""" - db = await get_database() - try: - papers = [] - async for ref in db.papers.find(): - papers.append({ - "_id": str(ref["_id"]), - "paper_link": ref["paper_link"], - "paper_title": ref["paper_title"], - "upload_time": ref["upload_time"], - "file_hash": ref.get("file_hash") - }) - return papers - + return await db_ops.get_papers() except Exception as e: - print(f"Error getting papers: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - + @app.get("/paper/report/{file_hash}") -async def get_report( - file_hash: str -): - """从 Redis db190 直接通过文件哈希值读取已保存的文献报告""" - redis = await get_redis() - +async def get_report(file_hash: str): + """获取文献报告""" try: - # 选择 db190 - await redis.select(190) - report_key = f"paper_report:{file_hash}" - - # 获取已保存的报告 - existing_report = await redis.get(report_key) - - if not existing_report: + report = await db_ops.get_paper_report(file_hash) + if not report: raise HTTPException(status_code=404, detail="No saved paper report found") - - # 返回报告 - return json.loads(existing_report) - + return report except Exception as e: - print(f"Error getting paper report: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") class FileUpload(BaseModel): filename: str @@ -195,9 +93,6 @@ class BatchUploadRequest(BaseModel): @app.post("/paper/upload") async def batch_upload(request: BatchUploadRequest): """批量上传项目相关文献""" - db = await get_database() - redis = await get_redis() - try: uploaded_papers = [] papers_to_analyze = [] @@ -206,13 +101,12 @@ async def batch_upload(request: BatchUploadRequest): try: # 1. 创建新记录 paper = { - "paper_link": f"https://files.aiot.ml/pdf/{file.hash}", "paper_title": file.filename, "upload_time": datetime.now(timezone.utc), "file_hash": file.hash } - result = await db.papers.insert_one(paper) + result = await db_ops.insert_paper(paper) paper_info = { "paper_id": str(result.inserted_id), "file_hash": file.hash, @@ -221,13 +115,7 @@ async def batch_upload(request: BatchUploadRequest): uploaded_papers.append(paper_info) # 2. 设置分析状态 - await redis.select(190) - report_key = f"paper_report:{file.hash}" - initial_status = { - "status": "processing", - "message": "Analysis in progress" - } - await redis.set(report_key, json.dumps(initial_status)) + await db_ops.set_paper_processing_status(file.hash) papers_to_analyze.append(paper_info) except Exception as file_error: @@ -254,56 +142,49 @@ async def batch_upload(request: BatchUploadRequest): except Exception as e: print(f"批量上传错误: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") +async def get_pdf_content(file_hash: str) -> Optional[str]: + """从文件存储服务获取PDF内容""" + async with aiohttp.ClientSession() as session: + async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response: + if response.status != 200: + print(f"获取PDF内容失败: {response.status}") + return None + + pdf_content = await response.json() + if not pdf_content.get('content'): + return None + + content = pdf_content['content'] + if isinstance(content, list): + content = '\n'.join(content) + + return content async def batch_analysis(papers: List[dict]): """批量处理文献分析的后台任务""" - redis = await get_redis() - # 限制并发数量 semaphore = asyncio.Semaphore(3) async def single_paper(ref: dict): async with semaphore: try: - paper_id = ref["paper_id"] file_hash = ref["file_hash"] - # 再次检查报告是否存在(以防在开始分析前已经被其他进程分析) - await redis.select(190) - report_key = f"paper_report:{file_hash}" - existing_report = await redis.get(report_key) - if existing_report: - report_data = json.loads(existing_report) - if report_data.get("status") == "completed": - print(f"Report already exists for file hash: {file_hash}, skipping analysis") - return - # 从文件存储服务获取PDF内容 - async with aiohttp.ClientSession() as session: - async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response: - if response.status != 200: - print(f"[Redis] 获取PDF内容失败: {response.status}") - raise Exception(f"Failed to get PDF content for file hash: {file_hash}") - - pdf_content = await response.json() - if not pdf_content.get('content'): - raise Exception("No PDF content returned") - + # 检查报告是否存在 + if await db_ops.check_existing_report(file_hash): + print(f"Report already exists for file hash: {file_hash}, skipping analysis") + return + + # 获取PDF内容 + content = await get_pdf_content(file_hash) + if not content: + await db_ops.save_analysis_error(file_hash, "Failed to get PDF content") + return + print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}") - print(f"文献ID: {paper_id}") print(f"文件哈希: {file_hash}") - # 使用获取到的PDF内容继续处理 - content = pdf_content['content'] - # 如果content是列表,将其合并为单个字符串 - if isinstance(content, list): - content = '\n'.join(content) - # 打印字符数 content_length = len(content) print(f"\n=== 步骤2: 内容长度检查 ===") @@ -315,19 +196,22 @@ async def batch_analysis(papers: List[dict]): print(f"文档长度在处理范围内 ({content_length} <= 200000)") document_analysis = await analyze_paper(content[:180000]) if not document_analysis: - raise Exception("Failed to analyze document") + await db_ops.save_analysis_error(file_hash, "Failed to analyze document") + return print("文档分析完成") else: print(f"\n=== 步骤3B: 使用分段分析方式 ===") print(f"文档超过200000字符 ({content_length} > 200000)") analysis_results = await analyze_long_file(content) if not analysis_results: - raise Exception("Failed to analyze document in segments") + await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments") + return print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落") document_analysis = await merge_results(analysis_results) if not document_analysis: - raise Exception("Failed to merge analysis results") + await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results") + return # 等待一小段时间避免API限制 await asyncio.sleep(1) @@ -335,7 +219,8 @@ async def batch_analysis(papers: List[dict]): # 异步分析文献价值 value_evaluation = await paper_value(document_analysis) if not value_evaluation: - raise Exception("Failed to evaluate value") + await db_ops.save_analysis_error(file_hash, "Failed to evaluate value") + return print("文献价值分析完成") # 合并结果 @@ -347,98 +232,26 @@ async def batch_analysis(papers: List[dict]): } # 保存结果 - await redis.select(190) - await redis.set(report_key, json.dumps(analysis_result)) + await db_ops.save_paper_report(file_hash, analysis_result) except Exception as e: - try: - await redis.select(190) - report_key = f"paper_report:{file_hash}" - error_status = { - "status": "failed", - "message": str(e) - } - await redis.set(report_key, json.dumps(error_status)) - except Exception as redis_error: - print(f"Error updating Redis status: {redis_error}") + print(f"分析文献时出错: {str(e)}") + await db_ops.save_analysis_error(file_hash, str(e)) try: # 并发处理所有引用 await asyncio.gather( *(single_paper(ref) for ref in papers) ) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") - -# 在全局范围创建线程池 -pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量 -# DeepSeek API Configuration -client = OpenAI( - base_url="https://api.deepseek.com/v1", - api_key="sk-3027fb3c810b4e17985fa397d41250b9" -) -# 创建异步HTTP客户端会话 -async def get_aiohttp_session(): - return aiohttp.ClientSession( - base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠 - headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"} - ) - -async def read_pdf(file_path: str) -> str: - """在线程池中异步读取PDF""" - def read_pdf(): - try: - pdf_content = "" - with open(file_path, 'rb') as file: - pdf_reader = PyPDF2.PdfReader(file) - for page in pdf_reader.pages: - content = page.extract_text() - pdf_content += content - - # 根据内容长度返回不同的结果 - content_length = len(pdf_content) - print(f"\nPDF原始内容字符数: {content_length}") - - if content_length <= 180000: - print("文档长度 ≤ 180000,返回完整内容") - return pdf_content - elif content_length <= 200000: - print("文档长度在 180000-200000 之间,截取前 180000 个字符") - return pdf_content[:180000] - else: - print("文档长度 > 200000,返回完整内容供分段处理") - return pdf_content # 返回完整内容,由调用者处理分段 - - except Exception as e: - print(f"Error reading PDF: {e}") - return "" - - loop = asyncio.get_event_loop() - return await loop.run_in_executor(pdf_thread_pool, read_pdf) - -async def call_deepseek(messages: list) -> dict: - """异步调用DeepSeek API""" - async with await get_aiohttp_session() as session: - async with session.post("/chat/completions", json={ - "model": "deepseek-chat", - "messages": messages, - "response_format": {"type": "json_object"} - }) as response: - if response.status == 200: - data = await response.json() - return json.loads(data["choices"][0]["message"]["content"]) - else: - raise Exception(f"API调用失败: {await response.text()}") + except Exception as e: + print(f"批量分析任务出错: {str(e)}") async def analyze_long_file(content: str) -> List[dict]: """分段分析长文档""" # 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens) segments = [] content_length = len(content) - segment_size = 60000 # 减小段落大小 + segment_size = 180000 # 减小段落大小 print(f"\n开始分段处理,总字符数: {content_length}") print(f"每段大小: {segment_size} 字符") @@ -598,6 +411,7 @@ async def merge_results(results: List[dict]) -> dict: print("分析结果合并失败") return result + async def analyze_paper(content: str): """分析文献的基本信息和内容""" system_prompt = """ @@ -698,20 +512,9 @@ class TaskStatus: FAILED = "failed" @app.post("/paper/{file_hash}/qa") -async def ask_reference_question( - file_hash: str, - question: QuestionModel -): +async def ask_reference_question(file_hash: str, question: QuestionModel): """Ask questions about the paper (async)""" - db = await get_database() - redis = await get_redis() - try: - # Verify paper exists - paper = await db.papers.find_one({"file_hash": file_hash}) - if not paper: - raise HTTPException(status_code=404, detail="Paper not found") - # Generate task ID task_id = str(ObjectId()) @@ -719,8 +522,7 @@ async def ask_reference_question( asyncio.create_task(paper_question( task_id=task_id, file_hash=file_hash, - question=question.question, - paper=paper + question=question.question )) return { @@ -730,56 +532,36 @@ async def ask_reference_question( } except Exception as e: - print(f"Error in ask_reference_question: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") -async def paper_question(task_id: str, file_hash: str, question: str, paper: dict): +async def paper_question(task_id: str, file_hash: str, question: str): """Process paper Q&A background task""" - redis = await get_redis() - try: + # Get paper info + paper = await db_ops.get_paper(file_hash) + if not paper: + raise Exception("Paper not found") + # Update task status to processing - await redis.select(192) # Use db192 to store task status - await redis.hset( - f"task:{task_id}", - mapping={ - "status": TaskStatus.PROCESSING, - "file_hash": file_hash, - "question": question - } - ) - - # Get analysis report from Redis - await redis.select(190) - report_key = f"paper_report:{file_hash}" - report_data = await redis.get(report_key) + await db_ops.save_task_status(task_id, { + "status": TaskStatus.PROCESSING, + "file_hash": file_hash, + "question": question + }) + # Get analysis report + report_data = await db_ops.get_paper_report(file_hash) if not report_data: raise Exception("Paper analysis report does not exist, please analyze first") - # Get PDF content from file storage service - async with aiohttp.ClientSession() as session: - async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response: - if response.status != 200: - raise Exception(f"Failed to get PDF content: HTTP {response.status}") - - pdf_content = await response.json() - if not pdf_content.get('content'): - raise Exception("No PDF content returned") - - content = pdf_content['content'] - # If content is a list, join it into a single string - if isinstance(content, list): - content = '\n'.join(content) - - # Limit content length - MAX_CHARS = 180000 - content = content[:MAX_CHARS] + # Get PDF content + content = await get_pdf_content(file_hash) + if not content: + raise Exception("Failed to get PDF content") + + # Limit content length + MAX_CHARS = 180000 + content = content[:MAX_CHARS] # Build system prompt and user prompt system_prompt = """ @@ -791,7 +573,7 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic # Build context context = { "Paper Content": content, - "Analysis Report": json.loads(report_data) + "Analysis Report": report_data } messages = [ @@ -800,94 +582,39 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic ] # Call DeepSeek API for answer - response = client.chat.completions.create( - model="deepseek-chat", - messages=messages - ) - answer = response.choices[0].message.content + answer = await call_deepseek_sync(messages) - # Save conversation history to Redis db191 - await redis.select(191) - chat_history_key = f"chat_history:{file_hash}" - - # Get existing history - existing_history = await redis.get(chat_history_key) - history = json.loads(existing_history) if existing_history else [] - - # Add new conversation - history.append({ - "question": question, - "answer": answer, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - # Save updated history - await redis.set(chat_history_key, json.dumps(history)) + # Save conversation history + await db_ops.save_qa_history(file_hash, question, answer) # Update task status to completed - await redis.select(192) - await redis.hset( - f"task:{task_id}", - mapping={ - "status": TaskStatus.COMPLETED, - "answer": answer, - "paper_title": paper.get("paper_title") - } - ) + await db_ops.save_task_status(task_id, { + "status": TaskStatus.COMPLETED, + "answer": answer, + "paper_title": paper.get("paper_title") + }) except Exception as e: print(f"Error processing question: {e}") # Update task status to failed - await redis.select(192) - await redis.hset( - f"task:{task_id}", - mapping={ - "status": TaskStatus.FAILED, - "error": str(e) - } - ) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") + await db_ops.save_task_status(task_id, { + "status": TaskStatus.FAILED, + "error": str(e) + }) @app.get("/paper/{file_hash}/qa/history") -async def get_paper_qa_history( - file_hash: str -): +async def get_paper_qa_history(file_hash: str): """Get paper Q&A history""" - redis = await get_redis() - try: - # Get conversation history from Redis db191 - await redis.select(191) - chat_history_key = f"chat_history:{file_hash}" - - history_data = await redis.get(chat_history_key) - if not history_data: - return [] - - return json.loads(history_data) - + return await db_ops.get_qa_history(file_hash) except Exception as e: - print(f"Error getting QA history: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") @app.get("/paper/task/{task_id}") async def get_task_status(task_id: str): """Get task status and result""" - redis = await get_redis() - try: - await redis.select(192) - task_data = await redis.hgetall(f"task:{task_id}") - + task_data = await db_ops.get_task_status(task_id) if not task_data: raise HTTPException(status_code=404, detail="Task not found") @@ -896,59 +623,30 @@ async def get_task_status(task_id: str): "status": task_data.get("status", TaskStatus.PENDING) } - # If task completed, add results if task_data.get("status") == TaskStatus.COMPLETED: response.update({ "answer": task_data.get("answer"), "paper_title": task_data.get("paper_title") }) - # If task failed, add error information elif task_data.get("status") == TaskStatus.FAILED: response.update({ "error": task_data.get("error") }) return response - except Exception as e: - print(f"Error getting task status: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") -# 添加新的路由,通过哈希值获取报告 @app.get("/paper/check/{file_hash}") -async def check_report( - file_hash: str -): - """直接通过文件哈希值从 Redis db190 读取已保存的文献报告""" - redis = await get_redis() - +async def check_report(file_hash: str): + """检查文献报告状态""" try: - # 选择 db190 - await redis.select(190) - report_key = f"paper_report:{file_hash}" - - # 获取已保存的报告 - existing_report = await redis.get(report_key) - - if not existing_report: + report = await db_ops.get_paper_report(file_hash) + if not report: return {"status": "not_found"} - - # 返回报告 - return json.loads(existing_report) - + return report except Exception as e: - print(f"Error getting paper report by hash: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) - finally: - try: - await redis.aclose() - except Exception as e: - print(f"Error closing Redis connection: {e}") if __name__ == "__main__": import uvicorn