This commit is contained in:
2025-01-19 07:52:50 +00:00
parent 064f56e4b9
commit 989c59b616
5 changed files with 383 additions and 414 deletions
Binary file not shown.
Binary file not shown.
+41
View File
@@ -0,0 +1,41 @@
from openai import OpenAI
import aiohttp
import json
# DeepSeek API Configuration
DEEPSEEK_API_KEY = "sk-3027fb3c810b4e17985fa397d41250b9"
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
client = OpenAI(
base_url=DEEPSEEK_BASE_URL,
api_key=DEEPSEEK_API_KEY
)
async def get_aiohttp_session():
"""获取异步HTTP会话"""
return aiohttp.ClientSession(
base_url=f"{DEEPSEEK_BASE_URL}/",
headers={"Authorization": f"Bearer {DEEPSEEK_API_KEY}"}
)
async def call_deepseek(messages: list) -> dict:
"""异步调用DeepSeek API"""
async with await get_aiohttp_session() as session:
async with session.post("/chat/completions", json={
"model": "deepseek-chat",
"messages": messages,
"response_format": {"type": "json_object"}
}) as response:
if response.status == 200:
data = await response.json()
return json.loads(data["choices"][0]["message"]["content"])
else:
raise Exception(f"API调用失败: {await response.text()}")
async def call_deepseek_sync(messages: list) -> str:
"""同步调用DeepSeek API"""
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages
)
return response.choices[0].message.content
+230
View File
@@ -0,0 +1,230 @@
from motor.motor_asyncio import AsyncIOMotorClient
from redis import asyncio as aioredis
from datetime import datetime, timezone
from bson import ObjectId
import json
from typing import List, Dict, Any, Optional
# 数据库配置
MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper"
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
# Redis 数据库索引
REDIS_REPORT_DB = 190 # 文献分析报告
REDIS_CHAT_DB = 191 # 聊天历史
REDIS_TASK_DB = 192 # 任务状态
class Database:
"""数据库连接管理类"""
client: AsyncIOMotorClient = None
db = Database()
async def get_database() -> AsyncIOMotorClient:
return db.client["paper"]
async def connect_to_mongo():
"""连接到MongoDB"""
try:
db.client = AsyncIOMotorClient(MONGODB_URL)
await db.client.admin.command('ping')
print("Successfully connected to MongoDB")
except Exception as e:
print(f"Could not connect to MongoDB: {e}")
raise
async def close_mongo_connection():
"""关闭MongoDB连接"""
db.client.close()
async def get_redis():
"""获取Redis连接"""
redis = aioredis.from_url(
REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
return redis
class DatabaseOperations:
"""数据库操作类"""
@staticmethod
async def insert_paper(paper: dict):
"""插入新文献"""
db = await get_database()
result = await db.papers.insert_one(paper)
return result
@staticmethod
async def set_paper_processing_status(file_hash: str):
"""设置文献处理状态"""
redis = await get_redis()
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
initial_status = {
"status": "processing",
"message": "Analysis in progress"
}
await redis.set(report_key, json.dumps(initial_status))
finally:
await redis.aclose()
@staticmethod
async def save_paper_report(file_hash: str, report_data: dict):
"""保存文献分析报告"""
redis = await get_redis()
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
await redis.set(report_key, json.dumps(report_data))
finally:
await redis.aclose()
@staticmethod
async def get_paper_report(file_hash: str):
"""获取文献分析报告"""
redis = await get_redis()
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
report_data = await redis.get(report_key)
return json.loads(report_data) if report_data else None
finally:
await redis.aclose()
@staticmethod
async def save_qa_history(file_hash: str, question: str, answer: str):
"""保存问答历史"""
redis = await get_redis()
try:
await redis.select(REDIS_CHAT_DB)
chat_history_key = f"chat_history:{file_hash}"
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
await redis.set(chat_history_key, json.dumps(history))
finally:
await redis.aclose()
@staticmethod
async def get_qa_history(file_hash: str):
"""获取问答历史"""
redis = await get_redis()
try:
await redis.select(REDIS_CHAT_DB)
chat_history_key = f"chat_history:{file_hash}"
history_data = await redis.get(chat_history_key)
return json.loads(history_data) if history_data else []
finally:
await redis.aclose()
@staticmethod
async def save_task_status(task_id: str, status_data: dict):
"""保存任务状态"""
redis = await get_redis()
try:
await redis.select(REDIS_TASK_DB)
await redis.hset(f"task:{task_id}", mapping=status_data)
finally:
await redis.aclose()
@staticmethod
async def get_task_status(task_id: str):
"""获取任务状态"""
redis = await get_redis()
try:
await redis.select(REDIS_TASK_DB)
task_data = await redis.hgetall(f"task:{task_id}")
return task_data if task_data else None
finally:
await redis.aclose()
@staticmethod
async def delete_paper(paper_id: str):
"""删除文献"""
db = await get_database()
redis = await get_redis()
try:
paper = await db.papers.find_one({"file_hash": paper_id})
if not paper:
return None
file_hash = paper.get("file_hash")
await db.papers.delete_one({"file_hash": paper_id})
# 删除Redis中的数据
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
await redis.delete(report_key)
await redis.select(REDIS_CHAT_DB)
chat_history_key = f"chat_history:{paper_id}"
await redis.delete(chat_history_key)
finally:
await redis.aclose()
return {"message": "paper successfully deleted"}
except Exception as e:
raise Exception(f"Failed to delete paper: {str(e)}")
@staticmethod
async def get_papers():
"""获取所有文献列表"""
db = await get_database()
papers = []
async for ref in db.papers.find():
papers.append({
"_id": str(ref["_id"]),
"paper_title": ref["paper_title"],
"upload_time": ref["upload_time"],
"file_hash": ref.get("file_hash")
})
return papers
@staticmethod
async def get_paper(file_hash: str):
"""获取单个文献信息"""
db = await get_database()
return await db.papers.find_one({"file_hash": file_hash})
@staticmethod
async def check_existing_report(file_hash: str):
"""检查是否已存在分析报告"""
redis = await get_redis()
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
existing_report = await redis.get(report_key)
if existing_report:
report_data = json.loads(existing_report)
if report_data.get("status") == "completed":
return True
return False
finally:
await redis.aclose()
@staticmethod
async def save_analysis_error(file_hash: str, error_message: str):
"""保存分析错误状态"""
redis = await get_redis()
try:
await redis.select(REDIS_REPORT_DB)
report_key = f"paper_report:{file_hash}"
error_status = {
"status": "failed",
"message": str(error_message)
}
await redis.set(report_key, json.dumps(error_status))
finally:
await redis.aclose()
+112 -414
View File
@@ -1,59 +1,23 @@
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse from typing import List, Optional
from typing import Dict, List, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
import asyncio import asyncio
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId from bson import ObjectId
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from redis import asyncio as aioredis
from typing import Any
from contextlib import asynccontextmanager
import json import json
import os
import PyPDF2
import aiohttp import aiohttp
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI from contextlib import asynccontextmanager
import hashlib
# Database Configuration from database import (
MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper" connect_to_mongo, close_mongo_connection,
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379" DatabaseOperations as db_ops
)
from api import call_deepseek, call_deepseek_sync
# 在全局范围创建线程池
# Database connection pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
class Database:
"""数据库连接管理类"""
client: AsyncIOMotorClient = None
db = Database()
async def get_database() -> AsyncIOMotorClient:
return db.client["paper"]
async def connect_to_mongo():
try:
db.client = AsyncIOMotorClient(MONGODB_URL)
# 验证连接
await db.client.admin.command('ping')
print("Successfully connected to MongoDB")
except Exception as e:
print(f"Could not connect to MongoDB: {e}")
raise
async def close_mongo_connection():
db.client.close()
# Redis setup
async def get_redis():
redis = aioredis.from_url(
REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
return redis
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -76,7 +40,6 @@ app.add_middleware(
expose_headers=["*"] expose_headers=["*"]
) )
# 在其他 Pydantic 模型后添加 # 在其他 Pydantic 模型后添加
class paperModel(BaseModel): class paperModel(BaseModel):
"""文献引用模型""" """文献引用模型"""
@@ -89,101 +52,36 @@ class paperModel(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
json_encoders = {ObjectId: str} json_encoders = {ObjectId: str}
# 删除文献 # 删除文献
@app.delete("/paper/delete/{paper_id}") @app.delete("/paper/delete/{paper_id}")
async def delete_paper( async def delete_paper(paper_id: str):
paper_id: str
):
"""删除文献""" """删除文献"""
db = await get_database()
try: try:
# 获取文献信息 result = await db_ops.delete_paper(paper_id)
paper = await db.papers.find_one({"file_hash": paper_id}) if not result:
if not paper:
raise HTTPException(status_code=404, detail="paper not found") raise HTTPException(status_code=404, detail="paper not found")
return result
# 从文件存储服务删除文件
file_hash = paper.get("file_hash")
# 删除数据库记录
await db.papers.delete_one({"file_hash": paper_id})
# 删除Redis中的分析报告(如果存在)
try:
redis = await get_redis()
await redis.select(190)
report_key = f"paper_report:{file_hash}"
await redis.delete(report_key)
# 删除聊天历史(如果存在)
await redis.select(191)
chat_history_key = f"chat_history:{paper_id}"
await redis.delete(chat_history_key)
except Exception as redis_error:
print(f"Warning: Failed to delete Redis keys: {str(redis_error)}")
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
return {"message": "paper successfully deleted"}
except Exception as e: except Exception as e:
print(f"Error deleting paper: {str(e)}") raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=f"Failed to delete paper: {str(e)}")
@app.get("/paper/papers") @app.get("/paper/papers")
async def get_papers(): async def get_papers():
"""获取文献列表""" """获取文献列表"""
db = await get_database()
try: try:
papers = [] return await db_ops.get_papers()
async for ref in db.papers.find():
papers.append({
"_id": str(ref["_id"]),
"paper_link": ref["paper_link"],
"paper_title": ref["paper_title"],
"upload_time": ref["upload_time"],
"file_hash": ref.get("file_hash")
})
return papers
except Exception as e: except Exception as e:
print(f"Error getting papers: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/paper/report/{file_hash}") @app.get("/paper/report/{file_hash}")
async def get_report( async def get_report(file_hash: str):
file_hash: str """获取文献报告"""
):
"""从 Redis db190 直接通过文件哈希值读取已保存的文献报告"""
redis = await get_redis()
try: try:
# 选择 db190 report = await db_ops.get_paper_report(file_hash)
await redis.select(190) if not report:
report_key = f"paper_report:{file_hash}"
# 获取已保存的报告
existing_report = await redis.get(report_key)
if not existing_report:
raise HTTPException(status_code=404, detail="No saved paper report found") raise HTTPException(status_code=404, detail="No saved paper report found")
return report
# 返回报告
return json.loads(existing_report)
except Exception as e: except Exception as e:
print(f"Error getting paper report: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
class FileUpload(BaseModel): class FileUpload(BaseModel):
filename: str filename: str
@@ -195,9 +93,6 @@ class BatchUploadRequest(BaseModel):
@app.post("/paper/upload") @app.post("/paper/upload")
async def batch_upload(request: BatchUploadRequest): async def batch_upload(request: BatchUploadRequest):
"""批量上传项目相关文献""" """批量上传项目相关文献"""
db = await get_database()
redis = await get_redis()
try: try:
uploaded_papers = [] uploaded_papers = []
papers_to_analyze = [] papers_to_analyze = []
@@ -206,13 +101,12 @@ async def batch_upload(request: BatchUploadRequest):
try: try:
# 1. 创建新记录 # 1. 创建新记录
paper = { paper = {
"paper_link": f"https://files.aiot.ml/pdf/{file.hash}",
"paper_title": file.filename, "paper_title": file.filename,
"upload_time": datetime.now(timezone.utc), "upload_time": datetime.now(timezone.utc),
"file_hash": file.hash "file_hash": file.hash
} }
result = await db.papers.insert_one(paper) result = await db_ops.insert_paper(paper)
paper_info = { paper_info = {
"paper_id": str(result.inserted_id), "paper_id": str(result.inserted_id),
"file_hash": file.hash, "file_hash": file.hash,
@@ -221,13 +115,7 @@ async def batch_upload(request: BatchUploadRequest):
uploaded_papers.append(paper_info) uploaded_papers.append(paper_info)
# 2. 设置分析状态 # 2. 设置分析状态
await redis.select(190) await db_ops.set_paper_processing_status(file.hash)
report_key = f"paper_report:{file.hash}"
initial_status = {
"status": "processing",
"message": "Analysis in progress"
}
await redis.set(report_key, json.dumps(initial_status))
papers_to_analyze.append(paper_info) papers_to_analyze.append(paper_info)
except Exception as file_error: except Exception as file_error:
@@ -254,56 +142,49 @@ async def batch_upload(request: BatchUploadRequest):
except Exception as e: except Exception as e:
print(f"批量上传错误: {str(e)}") print(f"批量上传错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def get_pdf_content(file_hash: str) -> Optional[str]:
"""从文件存储服务获取PDF内容"""
async with aiohttp.ClientSession() as session:
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
if response.status != 200:
print(f"获取PDF内容失败: {response.status}")
return None
pdf_content = await response.json()
if not pdf_content.get('content'):
return None
content = pdf_content['content']
if isinstance(content, list):
content = '\n'.join(content)
return content
async def batch_analysis(papers: List[dict]): async def batch_analysis(papers: List[dict]):
"""批量处理文献分析的后台任务""" """批量处理文献分析的后台任务"""
redis = await get_redis()
# 限制并发数量 # 限制并发数量
semaphore = asyncio.Semaphore(3) semaphore = asyncio.Semaphore(3)
async def single_paper(ref: dict): async def single_paper(ref: dict):
async with semaphore: async with semaphore:
try: try:
paper_id = ref["paper_id"]
file_hash = ref["file_hash"] file_hash = ref["file_hash"]
# 再次检查报告是否存在(以防在开始分析前已经被其他进程分析) # 检查报告是否存在
await redis.select(190) if await db_ops.check_existing_report(file_hash):
report_key = f"paper_report:{file_hash}" print(f"Report already exists for file hash: {file_hash}, skipping analysis")
existing_report = await redis.get(report_key) return
if existing_report:
report_data = json.loads(existing_report) # 获取PDF内容
if report_data.get("status") == "completed": content = await get_pdf_content(file_hash)
print(f"Report already exists for file hash: {file_hash}, skipping analysis") if not content:
return await db_ops.save_analysis_error(file_hash, "Failed to get PDF content")
# 从文件存储服务获取PDF内容 return
async with aiohttp.ClientSession() as session:
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
if response.status != 200:
print(f"[Redis] 获取PDF内容失败: {response.status}")
raise Exception(f"Failed to get PDF content for file hash: {file_hash}")
pdf_content = await response.json()
if not pdf_content.get('content'):
raise Exception("No PDF content returned")
print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}") print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
print(f"文献ID: {paper_id}")
print(f"文件哈希: {file_hash}") print(f"文件哈希: {file_hash}")
# 使用获取到的PDF内容继续处理
content = pdf_content['content']
# 如果content是列表,将其合并为单个字符串
if isinstance(content, list):
content = '\n'.join(content)
# 打印字符数 # 打印字符数
content_length = len(content) content_length = len(content)
print(f"\n=== 步骤2: 内容长度检查 ===") print(f"\n=== 步骤2: 内容长度检查 ===")
@@ -315,19 +196,22 @@ async def batch_analysis(papers: List[dict]):
print(f"文档长度在处理范围内 ({content_length} <= 200000)") print(f"文档长度在处理范围内 ({content_length} <= 200000)")
document_analysis = await analyze_paper(content[:180000]) document_analysis = await analyze_paper(content[:180000])
if not document_analysis: if not document_analysis:
raise Exception("Failed to analyze document") await db_ops.save_analysis_error(file_hash, "Failed to analyze document")
return
print("文档分析完成") print("文档分析完成")
else: else:
print(f"\n=== 步骤3B: 使用分段分析方式 ===") print(f"\n=== 步骤3B: 使用分段分析方式 ===")
print(f"文档超过200000字符 ({content_length} > 200000)") print(f"文档超过200000字符 ({content_length} > 200000)")
analysis_results = await analyze_long_file(content) analysis_results = await analyze_long_file(content)
if not analysis_results: if not analysis_results:
raise Exception("Failed to analyze document in segments") await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments")
return
print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落") print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
document_analysis = await merge_results(analysis_results) document_analysis = await merge_results(analysis_results)
if not document_analysis: if not document_analysis:
raise Exception("Failed to merge analysis results") await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results")
return
# 等待一小段时间避免API限制 # 等待一小段时间避免API限制
await asyncio.sleep(1) await asyncio.sleep(1)
@@ -335,7 +219,8 @@ async def batch_analysis(papers: List[dict]):
# 异步分析文献价值 # 异步分析文献价值
value_evaluation = await paper_value(document_analysis) value_evaluation = await paper_value(document_analysis)
if not value_evaluation: if not value_evaluation:
raise Exception("Failed to evaluate value") await db_ops.save_analysis_error(file_hash, "Failed to evaluate value")
return
print("文献价值分析完成") print("文献价值分析完成")
# 合并结果 # 合并结果
@@ -347,98 +232,26 @@ async def batch_analysis(papers: List[dict]):
} }
# 保存结果 # 保存结果
await redis.select(190) await db_ops.save_paper_report(file_hash, analysis_result)
await redis.set(report_key, json.dumps(analysis_result))
except Exception as e: except Exception as e:
try: print(f"分析文献时出错: {str(e)}")
await redis.select(190) await db_ops.save_analysis_error(file_hash, str(e))
report_key = f"paper_report:{file_hash}"
error_status = {
"status": "failed",
"message": str(e)
}
await redis.set(report_key, json.dumps(error_status))
except Exception as redis_error:
print(f"Error updating Redis status: {redis_error}")
try: try:
# 并发处理所有引用 # 并发处理所有引用
await asyncio.gather( await asyncio.gather(
*(single_paper(ref) for ref in papers) *(single_paper(ref) for ref in papers)
) )
finally: except Exception as e:
try: print(f"批量分析任务出错: {str(e)}")
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 在全局范围创建线程池
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
# DeepSeek API Configuration
client = OpenAI(
base_url="https://api.deepseek.com/v1",
api_key="sk-3027fb3c810b4e17985fa397d41250b9"
)
# 创建异步HTTP客户端会话
async def get_aiohttp_session():
return aiohttp.ClientSession(
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
)
async def read_pdf(file_path: str) -> str:
"""在线程池中异步读取PDF"""
def read_pdf():
try:
pdf_content = ""
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
content = page.extract_text()
pdf_content += content
# 根据内容长度返回不同的结果
content_length = len(pdf_content)
print(f"\nPDF原始内容字符数: {content_length}")
if content_length <= 180000:
print("文档长度 ≤ 180000,返回完整内容")
return pdf_content
elif content_length <= 200000:
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
return pdf_content[:180000]
else:
print("文档长度 > 200000,返回完整内容供分段处理")
return pdf_content # 返回完整内容,由调用者处理分段
except Exception as e:
print(f"Error reading PDF: {e}")
return ""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
async def call_deepseek(messages: list) -> dict:
"""异步调用DeepSeek API"""
async with await get_aiohttp_session() as session:
async with session.post("/chat/completions", json={
"model": "deepseek-chat",
"messages": messages,
"response_format": {"type": "json_object"}
}) as response:
if response.status == 200:
data = await response.json()
return json.loads(data["choices"][0]["message"]["content"])
else:
raise Exception(f"API调用失败: {await response.text()}")
async def analyze_long_file(content: str) -> List[dict]: async def analyze_long_file(content: str) -> List[dict]:
"""分段分析长文档""" """分段分析长文档"""
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens # 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens
segments = [] segments = []
content_length = len(content) content_length = len(content)
segment_size = 60000 # 减小段落大小 segment_size = 180000 # 减小段落大小
print(f"\n开始分段处理,总字符数: {content_length}") print(f"\n开始分段处理,总字符数: {content_length}")
print(f"每段大小: {segment_size} 字符") print(f"每段大小: {segment_size} 字符")
@@ -598,6 +411,7 @@ async def merge_results(results: List[dict]) -> dict:
print("分析结果合并失败") print("分析结果合并失败")
return result return result
async def analyze_paper(content: str): async def analyze_paper(content: str):
"""分析文献的基本信息和内容""" """分析文献的基本信息和内容"""
system_prompt = """ system_prompt = """
@@ -698,20 +512,9 @@ class TaskStatus:
FAILED = "failed" FAILED = "failed"
@app.post("/paper/{file_hash}/qa") @app.post("/paper/{file_hash}/qa")
async def ask_reference_question( async def ask_reference_question(file_hash: str, question: QuestionModel):
file_hash: str,
question: QuestionModel
):
"""Ask questions about the paper (async)""" """Ask questions about the paper (async)"""
db = await get_database()
redis = await get_redis()
try: try:
# Verify paper exists
paper = await db.papers.find_one({"file_hash": file_hash})
if not paper:
raise HTTPException(status_code=404, detail="Paper not found")
# Generate task ID # Generate task ID
task_id = str(ObjectId()) task_id = str(ObjectId())
@@ -719,8 +522,7 @@ async def ask_reference_question(
asyncio.create_task(paper_question( asyncio.create_task(paper_question(
task_id=task_id, task_id=task_id,
file_hash=file_hash, file_hash=file_hash,
question=question.question, question=question.question
paper=paper
)) ))
return { return {
@@ -730,56 +532,36 @@ async def ask_reference_question(
} }
except Exception as e: except Exception as e:
print(f"Error in ask_reference_question: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def paper_question(task_id: str, file_hash: str, question: str, paper: dict): async def paper_question(task_id: str, file_hash: str, question: str):
"""Process paper Q&A background task""" """Process paper Q&A background task"""
redis = await get_redis()
try: try:
# Get paper info
paper = await db_ops.get_paper(file_hash)
if not paper:
raise Exception("Paper not found")
# Update task status to processing # Update task status to processing
await redis.select(192) # Use db192 to store task status await db_ops.save_task_status(task_id, {
await redis.hset( "status": TaskStatus.PROCESSING,
f"task:{task_id}", "file_hash": file_hash,
mapping={ "question": question
"status": TaskStatus.PROCESSING, })
"file_hash": file_hash,
"question": question
}
)
# Get analysis report from Redis
await redis.select(190)
report_key = f"paper_report:{file_hash}"
report_data = await redis.get(report_key)
# Get analysis report
report_data = await db_ops.get_paper_report(file_hash)
if not report_data: if not report_data:
raise Exception("Paper analysis report does not exist, please analyze first") raise Exception("Paper analysis report does not exist, please analyze first")
# Get PDF content from file storage service # Get PDF content
async with aiohttp.ClientSession() as session: content = await get_pdf_content(file_hash)
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response: if not content:
if response.status != 200: raise Exception("Failed to get PDF content")
raise Exception(f"Failed to get PDF content: HTTP {response.status}")
# Limit content length
pdf_content = await response.json() MAX_CHARS = 180000
if not pdf_content.get('content'): content = content[:MAX_CHARS]
raise Exception("No PDF content returned")
content = pdf_content['content']
# If content is a list, join it into a single string
if isinstance(content, list):
content = '\n'.join(content)
# Limit content length
MAX_CHARS = 180000
content = content[:MAX_CHARS]
# Build system prompt and user prompt # Build system prompt and user prompt
system_prompt = """ system_prompt = """
@@ -791,7 +573,7 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
# Build context # Build context
context = { context = {
"Paper Content": content, "Paper Content": content,
"Analysis Report": json.loads(report_data) "Analysis Report": report_data
} }
messages = [ messages = [
@@ -800,94 +582,39 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
] ]
# Call DeepSeek API for answer # Call DeepSeek API for answer
response = client.chat.completions.create( answer = await call_deepseek_sync(messages)
model="deepseek-chat",
messages=messages
)
answer = response.choices[0].message.content
# Save conversation history to Redis db191 # Save conversation history
await redis.select(191) await db_ops.save_qa_history(file_hash, question, answer)
chat_history_key = f"chat_history:{file_hash}"
# Get existing history
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
# Add new conversation
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
# Save updated history
await redis.set(chat_history_key, json.dumps(history))
# Update task status to completed # Update task status to completed
await redis.select(192) await db_ops.save_task_status(task_id, {
await redis.hset( "status": TaskStatus.COMPLETED,
f"task:{task_id}", "answer": answer,
mapping={ "paper_title": paper.get("paper_title")
"status": TaskStatus.COMPLETED, })
"answer": answer,
"paper_title": paper.get("paper_title")
}
)
except Exception as e: except Exception as e:
print(f"Error processing question: {e}") print(f"Error processing question: {e}")
# Update task status to failed # Update task status to failed
await redis.select(192) await db_ops.save_task_status(task_id, {
await redis.hset( "status": TaskStatus.FAILED,
f"task:{task_id}", "error": str(e)
mapping={ })
"status": TaskStatus.FAILED,
"error": str(e)
}
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@app.get("/paper/{file_hash}/qa/history") @app.get("/paper/{file_hash}/qa/history")
async def get_paper_qa_history( async def get_paper_qa_history(file_hash: str):
file_hash: str
):
"""Get paper Q&A history""" """Get paper Q&A history"""
redis = await get_redis()
try: try:
# Get conversation history from Redis db191 return await db_ops.get_qa_history(file_hash)
await redis.select(191)
chat_history_key = f"chat_history:{file_hash}"
history_data = await redis.get(chat_history_key)
if not history_data:
return []
return json.loads(history_data)
except Exception as e: except Exception as e:
print(f"Error getting QA history: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@app.get("/paper/task/{task_id}") @app.get("/paper/task/{task_id}")
async def get_task_status(task_id: str): async def get_task_status(task_id: str):
"""Get task status and result""" """Get task status and result"""
redis = await get_redis()
try: try:
await redis.select(192) task_data = await db_ops.get_task_status(task_id)
task_data = await redis.hgetall(f"task:{task_id}")
if not task_data: if not task_data:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -896,59 +623,30 @@ async def get_task_status(task_id: str):
"status": task_data.get("status", TaskStatus.PENDING) "status": task_data.get("status", TaskStatus.PENDING)
} }
# If task completed, add results
if task_data.get("status") == TaskStatus.COMPLETED: if task_data.get("status") == TaskStatus.COMPLETED:
response.update({ response.update({
"answer": task_data.get("answer"), "answer": task_data.get("answer"),
"paper_title": task_data.get("paper_title") "paper_title": task_data.get("paper_title")
}) })
# If task failed, add error information
elif task_data.get("status") == TaskStatus.FAILED: elif task_data.get("status") == TaskStatus.FAILED:
response.update({ response.update({
"error": task_data.get("error") "error": task_data.get("error")
}) })
return response return response
except Exception as e: except Exception as e:
print(f"Error getting task status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 添加新的路由,通过哈希值获取报告
@app.get("/paper/check/{file_hash}") @app.get("/paper/check/{file_hash}")
async def check_report( async def check_report(file_hash: str):
file_hash: str """检查文献报告状态"""
):
"""直接通过文件哈希值从 Redis db190 读取已保存的文献报告"""
redis = await get_redis()
try: try:
# 选择 db190 report = await db_ops.get_paper_report(file_hash)
await redis.select(190) if not report:
report_key = f"paper_report:{file_hash}"
# 获取已保存的报告
existing_report = await redis.get(report_key)
if not existing_report:
return {"status": "not_found"} return {"status": "not_found"}
return report
# 返回报告
return json.loads(existing_report)
except Exception as e: except Exception as e:
print(f"Error getting paper report by hash: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn