This commit is contained in:
2025-01-19 07:52:50 +00:00
parent 064f56e4b9
commit 989c59b616
5 changed files with 383 additions and 414 deletions
Binary file not shown.
Binary file not shown.
+41
View File
@@ -0,0 +1,41 @@
from openai import OpenAI
import aiohttp
import json
import os

# DeepSeek API configuration.
#
# SECURITY: the literal key below is committed to version control and should
# be treated as compromised (rotate it). It is kept only as a fallback so
# existing deployments keep working; set the DEEPSEEK_API_KEY environment
# variable to override it, then remove the literal.
DEEPSEEK_API_KEY = os.environ.get(
    "DEEPSEEK_API_KEY",
    "sk-3027fb3c810b4e17985fa397d41250b9",
)
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"

# Shared synchronous OpenAI-compatible client pointed at the DeepSeek endpoint.
client = OpenAI(
    base_url=DEEPSEEK_BASE_URL,
    api_key=DEEPSEEK_API_KEY,
)
async def get_aiohttp_session():
    """Create a new aiohttp session preconfigured for the DeepSeek API.

    The session's base URL is ``DEEPSEEK_BASE_URL`` with a trailing slash and
    every request carries the bearer-token ``Authorization`` header.

    Returns:
        A fresh ``aiohttp.ClientSession``; the caller is responsible for
        closing it (for example with ``async with``).
    """
    auth_headers = {"Authorization": f"Bearer {DEEPSEEK_API_KEY}"}
    api_root = f"{DEEPSEEK_BASE_URL}/"
    return aiohttp.ClientSession(base_url=api_root, headers=auth_headers)
async def call_deepseek(messages: list) -> dict:
    """Call the DeepSeek chat-completions endpoint and return parsed JSON.

    The request asks for JSON mode (``response_format`` of type
    ``json_object``) and the first choice's message content is decoded with
    ``json.loads``.

    Args:
        messages: Chat messages to send as the ``messages`` field.

    Returns:
        The decoded JSON value from the model's reply.

    Raises:
        RuntimeError: If the API answers with a status other than 200; the
            response body is included in the message.
        json.JSONDecodeError: If the returned content is not valid JSON.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": messages,
        "response_format": {"type": "json_object"},
    }
    async with await get_aiohttp_session() as session:
        # The path is relative (no leading slash) on purpose: the session's
        # base URL ends in "/v1/", and under RFC 3986 joining a path that
        # starts with "/" would replace that prefix instead of extending it.
        async with session.post("chat/completions", json=payload) as response:
            if response.status != 200:
                raise RuntimeError(f"API调用失败: {await response.text()}")
            data = await response.json()
    return json.loads(data["choices"][0]["message"]["content"])
async def call_deepseek_sync(messages: list) -> str:
    """Get a plain-text completion from DeepSeek via the synchronous client.

    The OpenAI SDK call used here is blocking, so it is run in the event
    loop's default executor; awaiting this coroutine therefore no longer
    stalls other tasks while the HTTP request is in flight.

    Args:
        messages: Chat messages to send as the ``messages`` field.

    Returns:
        The message content of the first choice in the response.
    """
    # Local imports keep this block self-contained; the module itself does
    # not import these names at the top.
    import asyncio
    import functools

    blocking_request = functools.partial(
        client.chat.completions.create,
        model="deepseek-chat",
        messages=messages,
    )
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, blocking_request)
    return response.choices[0].message.content
+230
View File
@@ -0,0 +1,230 @@
from motor.motor_asyncio import AsyncIOMotorClient
from redis import asyncio as aioredis
from datetime import datetime, timezone
from bson import ObjectId
import json
from typing import List, Dict, Any, Optional
import os

# Database configuration.
#
# SECURITY: the literal URLs below embed credentials that are committed to
# version control and should be treated as compromised (rotate them). They
# are kept only as fallbacks so existing deployments keep working; set the
# MONGODB_URL / REDIS_URL environment variables to override them.
# NOTE(review): the Redis password contains a raw "@"; percent-encoding it
# as "%40" would be the safer spelling - confirm against the URL parser used.
MONGODB_URL = os.environ.get(
    "MONGODB_URL",
    "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper",
)
REDIS_URL = os.environ.get(
    "REDIS_URL",
    "redis://:Obscura@2024@222.186.10.253:6379",
)

# Redis logical database indexes
REDIS_REPORT_DB = 190  # paper analysis reports
REDIS_CHAT_DB = 191  # Q&A chat history
REDIS_TASK_DB = 192  # background task status
class Database:
    """Holder for the process-wide MongoDB client."""

    # Stays None until a client is assigned (see connect_to_mongo), so the
    # annotation must admit None.
    client: Optional[AsyncIOMotorClient] = None


# Module-level singleton used by the connection helpers in this module.
db = Database()
async def get_database() -> Any:
    """Return the ``paper`` database of the shared MongoDB client.

    The value returned is the result of indexing the client with
    ``"paper"`` (a database handle, not the client itself).

    Raises:
        RuntimeError: If no client has been assigned to ``db.client`` yet.
    """
    if db.client is None:
        # Fail with an actionable message instead of the opaque
        # "'NoneType' object is not subscriptable".
        raise RuntimeError(
            "MongoDB client is not initialised; call connect_to_mongo() first"
        )
    return db.client["paper"]
async def connect_to_mongo():
    """Create the shared MongoDB client and verify it with a ``ping``.

    The new client is published on ``db.client`` only after the ping
    succeeds. If creation or the ping fails, the error is printed, the
    half-initialised client is closed, and the exception is re-raised, so
    ``db.client`` never points at a client that failed verification.

    Raises:
        Exception: Whatever the client constructor or the ping raised.
    """
    candidate = None
    try:
        candidate = AsyncIOMotorClient(MONGODB_URL)
        await candidate.admin.command('ping')
    except Exception as e:
        print(f"Could not connect to MongoDB: {e}")
        if candidate is not None:
            candidate.close()
        raise
    db.client = candidate
    print("Successfully connected to MongoDB")
async def close_mongo_connection():
    """Close the shared MongoDB client, if one was ever created.

    Calling this before a client has been assigned is a no-op rather than
    an ``AttributeError``, which keeps shutdown paths safe after a failed
    startup.
    """
    if db.client is None:
        return
    db.client.close()
async def get_redis():
    """Build a new asyncio Redis client for ``REDIS_URL``.

    Responses are decoded to ``str`` using UTF-8. A separate client is
    constructed on every call; the caller is responsible for closing it.
    """
    client_options = {"encoding": "utf-8", "decode_responses": True}
    return aioredis.from_url(REDIS_URL, **client_options)
class DatabaseOperations:
    """Static helpers for the MongoDB ``papers`` collection and Redis stores.

    Redis layout written and read by this class:

    * ``REDIS_REPORT_DB``: ``paper_report:{file_hash}`` -> JSON string
    * ``REDIS_CHAT_DB``:   ``chat_history:{file_hash}`` -> JSON list string
    * ``REDIS_TASK_DB``:   ``task:{task_id}`` -> hash
    """

    @staticmethod
    def _report_key(file_hash):
        """Return the Redis key holding the analysis report for a file hash."""
        return f"paper_report:{file_hash}"

    @staticmethod
    def _chat_key(file_hash):
        """Return the Redis key holding the Q&A history for a file hash."""
        return f"chat_history:{file_hash}"

    @staticmethod
    async def _with_redis(db_index, operation):
        """Run ``operation(redis)`` against logical database ``db_index``.

        A client is obtained from ``get_redis()``, switched to ``db_index``
        with SELECT, handed to the awaitable-returning ``operation``, and
        always closed afterwards, even if the operation raises.

        Returns:
            Whatever awaiting ``operation(redis)`` returns.
        """
        redis = await get_redis()
        try:
            await redis.select(db_index)
            return await operation(redis)
        finally:
            await redis.aclose()

    @staticmethod
    async def insert_paper(paper: dict):
        """Insert ``paper`` into the ``papers`` collection.

        Returns:
            The insert result produced by ``insert_one``.
        """
        database = await get_database()
        return await database.papers.insert_one(paper)

    @staticmethod
    async def set_paper_processing_status(file_hash: str):
        """Mark the report for ``file_hash`` as ``processing``."""
        payload = json.dumps({
            "status": "processing",
            "message": "Analysis in progress",
        })
        key = DatabaseOperations._report_key(file_hash)
        await DatabaseOperations._with_redis(
            REDIS_REPORT_DB, lambda redis: redis.set(key, payload)
        )

    @staticmethod
    async def save_paper_report(file_hash: str, report_data: dict):
        """Store ``report_data`` (JSON-encoded) as the report for ``file_hash``."""
        payload = json.dumps(report_data)
        key = DatabaseOperations._report_key(file_hash)
        await DatabaseOperations._with_redis(
            REDIS_REPORT_DB, lambda redis: redis.set(key, payload)
        )

    @staticmethod
    async def get_paper_report(file_hash: str):
        """Return the decoded report for ``file_hash``, or ``None`` if absent."""
        key = DatabaseOperations._report_key(file_hash)
        raw = await DatabaseOperations._with_redis(
            REDIS_REPORT_DB, lambda redis: redis.get(key)
        )
        return json.loads(raw) if raw else None

    @staticmethod
    async def save_qa_history(file_hash: str, question: str, answer: str):
        """Append a question/answer pair to the chat history of ``file_hash``.

        The history is stored as one JSON list; each entry records the
        question, the answer and a UTC ISO-8601 timestamp.

        NOTE: this is a read-modify-write on a single string key and is not
        atomic; two concurrent appends for the same file can overwrite each
        other. The storage format is kept unchanged for compatibility.
        """
        key = DatabaseOperations._chat_key(file_hash)

        async def append_entry(redis):
            raw = await redis.get(key)
            history = json.loads(raw) if raw else []
            history.append({
                "question": question,
                "answer": answer,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            await redis.set(key, json.dumps(history))

        await DatabaseOperations._with_redis(REDIS_CHAT_DB, append_entry)

    @staticmethod
    async def get_qa_history(file_hash: str):
        """Return the chat history list for ``file_hash`` (empty if none)."""
        key = DatabaseOperations._chat_key(file_hash)
        raw = await DatabaseOperations._with_redis(
            REDIS_CHAT_DB, lambda redis: redis.get(key)
        )
        return json.loads(raw) if raw else []

    @staticmethod
    async def save_task_status(task_id: str, status_data: dict):
        """Merge ``status_data`` into the hash stored at ``task:{task_id}``.

        Fields whose value is ``None`` are skipped because a Redis hash
        cannot store ``None`` and the client rejects such values; readers
        using ``.get()`` on the hash see ``None`` for them either way. If
        nothing remains to write, no Redis call is made.
        """
        fields = {
            name: value
            for name, value in status_data.items()
            if value is not None
        }
        if not fields:
            return
        await DatabaseOperations._with_redis(
            REDIS_TASK_DB,
            lambda redis: redis.hset(f"task:{task_id}", mapping=fields),
        )

    @staticmethod
    async def get_task_status(task_id: str):
        """Return the hash stored for ``task_id``, or ``None`` if it is empty."""
        task_data = await DatabaseOperations._with_redis(
            REDIS_TASK_DB, lambda redis: redis.hgetall(f"task:{task_id}")
        )
        return task_data if task_data else None

    @staticmethod
    async def delete_paper(paper_id: str):
        """Delete a paper record and its Redis report and chat history.

        Args:
            paper_id: Value matched against the ``file_hash`` field.

        Returns:
            ``None`` when no matching paper exists, otherwise a dict with a
            confirmation ``message``.

        Raises:
            Exception: Wrapping (and chained from) any error raised while
                looking up or deleting the data.
        """
        database = await get_database()
        try:
            paper = await database.papers.find_one({"file_hash": paper_id})
            if not paper:
                return None

            file_hash = paper.get("file_hash")
            await database.papers.delete_one({"file_hash": paper_id})

            # Remove the Redis-side data. The client is created only once
            # it is actually needed, so the not-found path above and Mongo
            # failures can no longer leave an unclosed client behind.
            redis = await get_redis()
            try:
                await redis.select(REDIS_REPORT_DB)
                await redis.delete(DatabaseOperations._report_key(file_hash))

                await redis.select(REDIS_CHAT_DB)
                await redis.delete(DatabaseOperations._chat_key(paper_id))
            finally:
                await redis.aclose()

            return {"message": "paper successfully deleted"}
        except Exception as e:
            raise Exception(f"Failed to delete paper: {str(e)}") from e

    @staticmethod
    async def get_papers():
        """Return a summary dict for every document in ``papers``.

        Each summary holds the stringified ``_id``, ``paper_title``,
        ``upload_time`` and ``file_hash`` (``None`` when the field is absent).
        """
        database = await get_database()
        return [
            {
                "_id": str(ref["_id"]),
                "paper_title": ref["paper_title"],
                "upload_time": ref["upload_time"],
                "file_hash": ref.get("file_hash"),
            }
            async for ref in database.papers.find()
        ]

    @staticmethod
    async def get_paper(file_hash: str):
        """Return the paper document whose ``file_hash`` matches, if any."""
        database = await get_database()
        return await database.papers.find_one({"file_hash": file_hash})

    @staticmethod
    async def check_existing_report(file_hash: str):
        """Return ``True`` only if a stored report has status ``completed``."""
        key = DatabaseOperations._report_key(file_hash)
        raw = await DatabaseOperations._with_redis(
            REDIS_REPORT_DB, lambda redis: redis.get(key)
        )
        if not raw:
            return False
        return json.loads(raw).get("status") == "completed"

    @staticmethod
    async def save_analysis_error(file_hash: str, error_message: str):
        """Record a ``failed`` status with ``error_message`` for ``file_hash``."""
        payload = json.dumps({
            "status": "failed",
            "message": str(error_message),
        })
        key = DatabaseOperations._report_key(file_hash)
        await DatabaseOperations._with_redis(
            REDIS_REPORT_DB, lambda redis: redis.set(key, payload)
        )
+93 -395
View File
@@ -1,59 +1,23 @@
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import Dict, List, Optional
from typing import List, Optional
from datetime import datetime, timezone
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
from pydantic import BaseModel, Field
from redis import asyncio as aioredis
from typing import Any
from contextlib import asynccontextmanager
import json
import os
import PyPDF2
import aiohttp
from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI
import hashlib
from contextlib import asynccontextmanager
# Database Configuration
MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper"
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
from database import (
connect_to_mongo, close_mongo_connection,
DatabaseOperations as db_ops
)
from api import call_deepseek, call_deepseek_sync
# Database connection
class Database:
"""数据库连接管理类"""
client: AsyncIOMotorClient = None
db = Database()
async def get_database() -> AsyncIOMotorClient:
return db.client["paper"]
async def connect_to_mongo():
try:
db.client = AsyncIOMotorClient(MONGODB_URL)
# 验证连接
await db.client.admin.command('ping')
print("Successfully connected to MongoDB")
except Exception as e:
print(f"Could not connect to MongoDB: {e}")
raise
async def close_mongo_connection():
db.client.close()
# Redis setup
async def get_redis():
redis = aioredis.from_url(
REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
return redis
# 在全局范围创建线程池
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -76,7 +40,6 @@ app.add_middleware(
expose_headers=["*"]
)
# 在其他 Pydantic 模型后添加
class paperModel(BaseModel):
"""文献引用模型"""
@@ -89,101 +52,36 @@ class paperModel(BaseModel):
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# 删除文献
@app.delete("/paper/delete/{paper_id}")
async def delete_paper(
paper_id: str
):
async def delete_paper(paper_id: str):
"""删除文献"""
db = await get_database()
try:
# 获取文献信息
paper = await db.papers.find_one({"file_hash": paper_id})
if not paper:
result = await db_ops.delete_paper(paper_id)
if not result:
raise HTTPException(status_code=404, detail="paper not found")
# 从文件存储服务删除文件
file_hash = paper.get("file_hash")
# 删除数据库记录
await db.papers.delete_one({"file_hash": paper_id})
# 删除Redis中的分析报告(如果存在)
try:
redis = await get_redis()
await redis.select(190)
report_key = f"paper_report:{file_hash}"
await redis.delete(report_key)
# 删除聊天历史(如果存在)
await redis.select(191)
chat_history_key = f"chat_history:{paper_id}"
await redis.delete(chat_history_key)
except Exception as redis_error:
print(f"Warning: Failed to delete Redis keys: {str(redis_error)}")
finally:
try:
await redis.aclose()
return result
except Exception as e:
print(f"Error closing Redis connection: {e}")
return {"message": "paper successfully deleted"}
except Exception as e:
print(f"Error deleting paper: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to delete paper: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/paper/papers")
async def get_papers():
"""获取文献列表"""
db = await get_database()
try:
papers = []
async for ref in db.papers.find():
papers.append({
"_id": str(ref["_id"]),
"paper_link": ref["paper_link"],
"paper_title": ref["paper_title"],
"upload_time": ref["upload_time"],
"file_hash": ref.get("file_hash")
})
return papers
return await db_ops.get_papers()
except Exception as e:
print(f"Error getting papers: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/paper/report/{file_hash}")
async def get_report(
file_hash: str
):
"""从 Redis db190 直接通过文件哈希值读取已保存的文献报告"""
redis = await get_redis()
async def get_report(file_hash: str):
"""获取文献报告"""
try:
# 选择 db190
await redis.select(190)
report_key = f"paper_report:{file_hash}"
# 获取已保存的报告
existing_report = await redis.get(report_key)
if not existing_report:
report = await db_ops.get_paper_report(file_hash)
if not report:
raise HTTPException(status_code=404, detail="No saved paper report found")
# 返回报告
return json.loads(existing_report)
return report
except Exception as e:
print(f"Error getting paper report: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
class FileUpload(BaseModel):
filename: str
@@ -195,9 +93,6 @@ class BatchUploadRequest(BaseModel):
@app.post("/paper/upload")
async def batch_upload(request: BatchUploadRequest):
"""批量上传项目相关文献"""
db = await get_database()
redis = await get_redis()
try:
uploaded_papers = []
papers_to_analyze = []
@@ -206,13 +101,12 @@ async def batch_upload(request: BatchUploadRequest):
try:
# 1. 创建新记录
paper = {
"paper_link": f"https://files.aiot.ml/pdf/{file.hash}",
"paper_title": file.filename,
"upload_time": datetime.now(timezone.utc),
"file_hash": file.hash
}
result = await db.papers.insert_one(paper)
result = await db_ops.insert_paper(paper)
paper_info = {
"paper_id": str(result.inserted_id),
"file_hash": file.hash,
@@ -221,13 +115,7 @@ async def batch_upload(request: BatchUploadRequest):
uploaded_papers.append(paper_info)
# 2. 设置分析状态
await redis.select(190)
report_key = f"paper_report:{file.hash}"
initial_status = {
"status": "processing",
"message": "Analysis in progress"
}
await redis.set(report_key, json.dumps(initial_status))
await db_ops.set_paper_processing_status(file.hash)
papers_to_analyze.append(paper_info)
except Exception as file_error:
@@ -254,56 +142,49 @@ async def batch_upload(request: BatchUploadRequest):
except Exception as e:
print(f"批量上传错误: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def get_pdf_content(file_hash: str) -> Optional[str]:
    """Fetch the extracted text of a PDF from the file storage service.

    Issues a GET for the file hash and reads the ``content`` field of the
    JSON response. A list value is joined with newlines into one string.

    Returns:
        The content, or ``None`` when the service does not answer with
        status 200 or the ``content`` field is missing or empty.
    """
    pdf_url = f'https://files.aiot.ml/pdf/{file_hash}'
    async with aiohttp.ClientSession() as session:
        async with session.get(pdf_url) as response:
            if response.status != 200:
                print(f"获取PDF内容失败: {response.status}")
                return None
            payload = await response.json()

    content = payload.get('content')
    if not content:
        return None
    return '\n'.join(content) if isinstance(content, list) else content
async def batch_analysis(papers: List[dict]):
"""批量处理文献分析的后台任务"""
redis = await get_redis()
# 限制并发数量
semaphore = asyncio.Semaphore(3)
async def single_paper(ref: dict):
async with semaphore:
try:
paper_id = ref["paper_id"]
file_hash = ref["file_hash"]
# 再次检查报告是否存在(以防在开始分析前已经被其他进程分析)
await redis.select(190)
report_key = f"paper_report:{file_hash}"
existing_report = await redis.get(report_key)
if existing_report:
report_data = json.loads(existing_report)
if report_data.get("status") == "completed":
# 检查报告是否存在
if await db_ops.check_existing_report(file_hash):
print(f"Report already exists for file hash: {file_hash}, skipping analysis")
return
# 从文件存储服务获取PDF内容
async with aiohttp.ClientSession() as session:
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
if response.status != 200:
print(f"[Redis] 获取PDF内容失败: {response.status}")
raise Exception(f"Failed to get PDF content for file hash: {file_hash}")
pdf_content = await response.json()
if not pdf_content.get('content'):
raise Exception("No PDF content returned")
# 获取PDF内容
content = await get_pdf_content(file_hash)
if not content:
await db_ops.save_analysis_error(file_hash, "Failed to get PDF content")
return
print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
print(f"文献ID: {paper_id}")
print(f"文件哈希: {file_hash}")
# 使用获取到的PDF内容继续处理
content = pdf_content['content']
# 如果content是列表,将其合并为单个字符串
if isinstance(content, list):
content = '\n'.join(content)
# 打印字符数
content_length = len(content)
print(f"\n=== 步骤2: 内容长度检查 ===")
@@ -315,19 +196,22 @@ async def batch_analysis(papers: List[dict]):
print(f"文档长度在处理范围内 ({content_length} <= 200000)")
document_analysis = await analyze_paper(content[:180000])
if not document_analysis:
raise Exception("Failed to analyze document")
await db_ops.save_analysis_error(file_hash, "Failed to analyze document")
return
print("文档分析完成")
else:
print(f"\n=== 步骤3B: 使用分段分析方式 ===")
print(f"文档超过200000字符 ({content_length} > 200000)")
analysis_results = await analyze_long_file(content)
if not analysis_results:
raise Exception("Failed to analyze document in segments")
await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments")
return
print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
document_analysis = await merge_results(analysis_results)
if not document_analysis:
raise Exception("Failed to merge analysis results")
await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results")
return
# 等待一小段时间避免API限制
await asyncio.sleep(1)
@@ -335,7 +219,8 @@ async def batch_analysis(papers: List[dict]):
# 异步分析文献价值
value_evaluation = await paper_value(document_analysis)
if not value_evaluation:
raise Exception("Failed to evaluate value")
await db_ops.save_analysis_error(file_hash, "Failed to evaluate value")
return
print("文献价值分析完成")
# 合并结果
@@ -347,98 +232,26 @@ async def batch_analysis(papers: List[dict]):
}
# 保存结果
await redis.select(190)
await redis.set(report_key, json.dumps(analysis_result))
await db_ops.save_paper_report(file_hash, analysis_result)
except Exception as e:
try:
await redis.select(190)
report_key = f"paper_report:{file_hash}"
error_status = {
"status": "failed",
"message": str(e)
}
await redis.set(report_key, json.dumps(error_status))
except Exception as redis_error:
print(f"Error updating Redis status: {redis_error}")
print(f"分析文献时出错: {str(e)}")
await db_ops.save_analysis_error(file_hash, str(e))
try:
# 并发处理所有引用
await asyncio.gather(
*(single_paper(ref) for ref in papers)
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 在全局范围创建线程池
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
# DeepSeek API Configuration
client = OpenAI(
base_url="https://api.deepseek.com/v1",
api_key="sk-3027fb3c810b4e17985fa397d41250b9"
)
# 创建异步HTTP客户端会话
async def get_aiohttp_session():
return aiohttp.ClientSession(
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
)
async def read_pdf(file_path: str) -> str:
"""在线程池中异步读取PDF"""
def read_pdf():
try:
pdf_content = ""
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
content = page.extract_text()
pdf_content += content
# 根据内容长度返回不同的结果
content_length = len(pdf_content)
print(f"\nPDF原始内容字符数: {content_length}")
if content_length <= 180000:
print("文档长度 ≤ 180000,返回完整内容")
return pdf_content
elif content_length <= 200000:
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
return pdf_content[:180000]
else:
print("文档长度 > 200000,返回完整内容供分段处理")
return pdf_content # 返回完整内容,由调用者处理分段
except Exception as e:
print(f"Error reading PDF: {e}")
return ""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
async def call_deepseek(messages: list) -> dict:
"""异步调用DeepSeek API"""
async with await get_aiohttp_session() as session:
async with session.post("/chat/completions", json={
"model": "deepseek-chat",
"messages": messages,
"response_format": {"type": "json_object"}
}) as response:
if response.status == 200:
data = await response.json()
return json.loads(data["choices"][0]["message"]["content"])
else:
raise Exception(f"API调用失败: {await response.text()}")
print(f"批量分析任务出错: {str(e)}")
async def analyze_long_file(content: str) -> List[dict]:
"""分段分析长文档"""
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens
segments = []
content_length = len(content)
segment_size = 60000 # 减小段落大小
segment_size = 180000 # 减小段落大小
print(f"\n开始分段处理,总字符数: {content_length}")
print(f"每段大小: {segment_size} 字符")
@@ -598,6 +411,7 @@ async def merge_results(results: List[dict]) -> dict:
print("分析结果合并失败")
return result
async def analyze_paper(content: str):
"""分析文献的基本信息和内容"""
system_prompt = """
@@ -698,20 +512,9 @@ class TaskStatus:
FAILED = "failed"
@app.post("/paper/{file_hash}/qa")
async def ask_reference_question(
file_hash: str,
question: QuestionModel
):
async def ask_reference_question(file_hash: str, question: QuestionModel):
"""Ask questions about the paper (async)"""
db = await get_database()
redis = await get_redis()
try:
# Verify paper exists
paper = await db.papers.find_one({"file_hash": file_hash})
if not paper:
raise HTTPException(status_code=404, detail="Paper not found")
# Generate task ID
task_id = str(ObjectId())
@@ -719,8 +522,7 @@ async def ask_reference_question(
asyncio.create_task(paper_question(
task_id=task_id,
file_hash=file_hash,
question=question.question,
paper=paper
question=question.question
))
return {
@@ -730,52 +532,32 @@ async def ask_reference_question(
}
except Exception as e:
print(f"Error in ask_reference_question: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def paper_question(task_id: str, file_hash: str, question: str, paper: dict):
async def paper_question(task_id: str, file_hash: str, question: str):
"""Process paper Q&A background task"""
redis = await get_redis()
try:
# Get paper info
paper = await db_ops.get_paper(file_hash)
if not paper:
raise Exception("Paper not found")
# Update task status to processing
await redis.select(192) # Use db192 to store task status
await redis.hset(
f"task:{task_id}",
mapping={
await db_ops.save_task_status(task_id, {
"status": TaskStatus.PROCESSING,
"file_hash": file_hash,
"question": question
}
)
# Get analysis report from Redis
await redis.select(190)
report_key = f"paper_report:{file_hash}"
report_data = await redis.get(report_key)
})
# Get analysis report
report_data = await db_ops.get_paper_report(file_hash)
if not report_data:
raise Exception("Paper analysis report does not exist, please analyze first")
# Get PDF content from file storage service
async with aiohttp.ClientSession() as session:
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
if response.status != 200:
raise Exception(f"Failed to get PDF content: HTTP {response.status}")
pdf_content = await response.json()
if not pdf_content.get('content'):
raise Exception("No PDF content returned")
content = pdf_content['content']
# If content is a list, join it into a single string
if isinstance(content, list):
content = '\n'.join(content)
# Get PDF content
content = await get_pdf_content(file_hash)
if not content:
raise Exception("Failed to get PDF content")
# Limit content length
MAX_CHARS = 180000
@@ -791,7 +573,7 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
# Build context
context = {
"Paper Content": content,
"Analysis Report": json.loads(report_data)
"Analysis Report": report_data
}
messages = [
@@ -800,94 +582,39 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
]
# Call DeepSeek API for answer
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages
)
answer = response.choices[0].message.content
answer = await call_deepseek_sync(messages)
# Save conversation history to Redis db191
await redis.select(191)
chat_history_key = f"chat_history:{file_hash}"
# Get existing history
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
# Add new conversation
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
# Save updated history
await redis.set(chat_history_key, json.dumps(history))
# Save conversation history
await db_ops.save_qa_history(file_hash, question, answer)
# Update task status to completed
await redis.select(192)
await redis.hset(
f"task:{task_id}",
mapping={
await db_ops.save_task_status(task_id, {
"status": TaskStatus.COMPLETED,
"answer": answer,
"paper_title": paper.get("paper_title")
}
)
})
except Exception as e:
print(f"Error processing question: {e}")
# Update task status to failed
await redis.select(192)
await redis.hset(
f"task:{task_id}",
mapping={
await db_ops.save_task_status(task_id, {
"status": TaskStatus.FAILED,
"error": str(e)
}
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
})
@app.get("/paper/{file_hash}/qa/history")
async def get_paper_qa_history(
file_hash: str
):
async def get_paper_qa_history(file_hash: str):
"""Get paper Q&A history"""
redis = await get_redis()
try:
# Get conversation history from Redis db191
await redis.select(191)
chat_history_key = f"chat_history:{file_hash}"
history_data = await redis.get(chat_history_key)
if not history_data:
return []
return json.loads(history_data)
return await db_ops.get_qa_history(file_hash)
except Exception as e:
print(f"Error getting QA history: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@app.get("/paper/task/{task_id}")
async def get_task_status(task_id: str):
"""Get task status and result"""
redis = await get_redis()
try:
await redis.select(192)
task_data = await redis.hgetall(f"task:{task_id}")
task_data = await db_ops.get_task_status(task_id)
if not task_data:
raise HTTPException(status_code=404, detail="Task not found")
@@ -896,59 +623,30 @@ async def get_task_status(task_id: str):
"status": task_data.get("status", TaskStatus.PENDING)
}
# If task completed, add results
if task_data.get("status") == TaskStatus.COMPLETED:
response.update({
"answer": task_data.get("answer"),
"paper_title": task_data.get("paper_title")
})
# If task failed, add error information
elif task_data.get("status") == TaskStatus.FAILED:
response.update({
"error": task_data.get("error")
})
return response
except Exception as e:
print(f"Error getting task status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 添加新的路由,通过哈希值获取报告
@app.get("/paper/check/{file_hash}")
async def check_report(
file_hash: str
):
"""直接通过文件哈希值从 Redis db190 读取已保存的文献报告"""
redis = await get_redis()
async def check_report(file_hash: str):
"""检查文献报告状态"""
try:
# 选择 db190
await redis.select(190)
report_key = f"paper_report:{file_hash}"
# 获取已保存的报告
existing_report = await redis.get(report_key)
if not existing_report:
report = await db_ops.get_paper_report(file_hash)
if not report:
return {"status": "not_found"}
# 返回报告
return json.loads(existing_report)
return report
except Exception as e:
print(f"Error getting paper report by hash: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
if __name__ == "__main__":
import uvicorn