update
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,41 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
|
||||||
|
# DeepSeek API configuration.
#
# SECURITY: the API key used to be available only as a literal committed to the
# repository. It can now be supplied through the DEEPSEEK_API_KEY environment
# variable; the literal is kept solely as a backward-compatible default.
# TODO(review): rotate this key and remove the literal once every deployment
# sets the environment variable.
import os  # imported here because only this configuration block needs it

DEEPSEEK_API_KEY = os.environ.get(
    "DEEPSEEK_API_KEY",
    "sk-3027fb3c810b4e17985fa397d41250b9",
)
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"

# Synchronous OpenAI-SDK client pointed at the DeepSeek endpoint.
client = OpenAI(
    base_url=DEEPSEEK_BASE_URL,
    api_key=DEEPSEEK_API_KEY,
)
|
||||||
|
|
||||||
|
async def get_aiohttp_session():
    """Create an aiohttp client session preconfigured for the DeepSeek API.

    The session uses ``DEEPSEEK_BASE_URL`` plus a trailing slash as its base
    URL and sends a bearer ``Authorization`` header built from
    ``DEEPSEEK_API_KEY``. A new session is created on every call; the caller
    is responsible for closing it (for example with ``async with``).
    """
    auth_headers = {"Authorization": f"Bearer {DEEPSEEK_API_KEY}"}
    api_root = f"{DEEPSEEK_BASE_URL}/"
    return aiohttp.ClientSession(base_url=api_root, headers=auth_headers)
|
||||||
|
|
||||||
|
async def call_deepseek(messages: list) -> dict:
    """Call the DeepSeek chat-completions endpoint and return the decoded reply.

    Args:
        messages: Chat messages, forwarded unchanged as the ``messages`` field
            of the request body.

    Returns:
        The content of the first choice's message, decoded with ``json.loads``
        (the request asks for ``response_format={"type": "json_object"}``).

    Raises:
        Exception: If the API answers with a status other than 200; the
            message carries the response body.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": messages,
        "response_format": {"type": "json_object"},
    }
    async with await get_aiohttp_session() as session:
        # The request path is deliberately relative (no leading slash). URL
        # joining follows RFC 3986, so "/chat/completions" would replace the
        # whole path of the session's base URL and silently drop its "/v1/"
        # prefix; "chat/completions" is appended to it instead.
        async with session.post("chat/completions", json=payload) as response:
            if response.status != 200:
                raise Exception(f"API调用失败: {await response.text()}")
            data = await response.json()
            return json.loads(data["choices"][0]["message"]["content"])
|
||||||
|
|
||||||
|
async def call_deepseek_sync(messages: list) -> str:
    """Call the DeepSeek API through the synchronous SDK ``client``.

    The SDK call blocks until the HTTP response arrives, so it is executed in
    the event loop's default thread pool; awaiting this coroutine therefore no
    longer stalls every other task on the loop.

    Args:
        messages: Chat messages, forwarded unchanged to the SDK.

    Returns:
        The content of the first choice's message, as returned by the SDK.
    """
    # Function-scope imports: this module does not import them at file level.
    import asyncio
    from functools import partial

    request = partial(
        client.chat.completions.create,
        model="deepseek-chat",
        messages=messages,
    )
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, request)
    return response.choices[0].message.content
|
||||||
+230
@@ -0,0 +1,230 @@
|
|||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from redis import asyncio as aioredis
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from bson import ObjectId
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
# Database configuration.
#
# SECURITY: both connection strings embed credentials and used to be available
# only as literals committed to the repository. They can now be supplied
# through the MONGODB_URL / REDIS_URL environment variables; the literals are
# kept solely as backward-compatible defaults.
# TODO(review): rotate these credentials and remove the literals once every
# deployment sets the environment variables.
import os  # imported here because only this configuration block needs it

MONGODB_URL = os.environ.get(
    "MONGODB_URL",
    "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper",
)
REDIS_URL = os.environ.get(
    "REDIS_URL",
    "redis://:Obscura@2024@222.186.10.253:6379",
)

# Redis logical database indexes.
# NOTE(review): indexes this high presumably require the server's `databases`
# setting to be raised accordingly - confirm against the deployment.
REDIS_REPORT_DB = 190  # paper analysis reports
REDIS_CHAT_DB = 191  # chat (Q&A) history
REDIS_TASK_DB = 192  # task status
|
||||||
|
|
||||||
|
class Database:
    """Holder for the process-wide MongoDB client."""

    # None until a client has been assigned; the annotation now reflects that.
    client: Optional[AsyncIOMotorClient] = None


# Module-level singleton instance of the holder.
db = Database()
|
||||||
|
|
||||||
|
async def get_database() -> AsyncIOMotorClient:
    """Return the ``paper`` database handle of the shared MongoDB client.

    NOTE(review): the return annotation is kept unchanged for compatibility,
    but the value returned is ``db.client["paper"]`` (a database handle), not
    the client itself.

    Raises:
        RuntimeError: If ``db.client`` is still ``None``, i.e. no client has
            been created yet.
    """
    if db.client is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not subscriptable".
        raise RuntimeError(
            "MongoDB client is not initialized; call connect_to_mongo() first"
        )
    return db.client["paper"]
|
||||||
|
|
||||||
|
async def connect_to_mongo():
    """Create the shared MongoDB client and verify the server answers a ping.

    The client is stored on ``db.client`` only after the ``ping`` command has
    succeeded. On failure the newly created client is closed, so a broken
    client is never left behind for later callers.

    Raises:
        Exception: Whatever client construction or the ``ping`` command
            raised, re-raised unchanged after being printed.
    """
    new_client = None
    try:
        new_client = AsyncIOMotorClient(MONGODB_URL)
        await new_client.admin.command('ping')
    except Exception as e:
        print(f"Could not connect to MongoDB: {e}")
        if new_client is not None:
            new_client.close()
        raise
    db.client = new_client
    print("Successfully connected to MongoDB")
|
||||||
|
|
||||||
|
async def close_mongo_connection():
    """Close the shared MongoDB client, if one exists.

    Does nothing when ``db.client`` is ``None`` (for example when shutdown
    runs although no connection was ever established). After closing, the
    reference is cleared so the closed client cannot be reused by accident.
    """
    if db.client is None:
        return
    db.client.close()
    db.client = None
|
||||||
|
|
||||||
|
async def get_redis():
    """Build a Redis client for ``REDIS_URL``.

    The client is configured with UTF-8 encoding and ``decode_responses=True``.
    A new client object is created on every call; the caller is responsible
    for closing it.
    """
    return aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||||
|
|
||||||
|
class DatabaseOperations:
    """Static helpers around the MongoDB ``papers`` collection and Redis.

    Redis is used for three kinds of data, each in its own logical database:
    analysis reports (``REDIS_REPORT_DB``), Q&A history (``REDIS_CHAT_DB``)
    and task status (``REDIS_TASK_DB``). Every Redis-backed helper obtains its
    own client from ``get_redis()``, selects the database it needs and closes
    the client in a ``finally`` block.
    """

    @staticmethod
    async def insert_paper(paper: dict):
        """Insert ``paper`` into the ``papers`` collection.

        Returns:
            The result object of ``insert_one``.
        """
        database = await get_database()
        return await database.papers.insert_one(paper)

    @staticmethod
    async def set_paper_processing_status(file_hash: str):
        """Mark the report for ``file_hash`` as being processed.

        Overwrites the ``paper_report:<file_hash>`` key with a JSON status
        object whose ``status`` is ``"processing"``.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_REPORT_DB)
            report_key = f"paper_report:{file_hash}"
            initial_status = {
                "status": "processing",
                "message": "Analysis in progress"
            }
            await redis.set(report_key, json.dumps(initial_status))
        finally:
            await redis.aclose()

    @staticmethod
    async def save_paper_report(file_hash: str, report_data: dict):
        """Store ``report_data`` as JSON under ``paper_report:<file_hash>``."""
        redis = await get_redis()
        try:
            await redis.select(REDIS_REPORT_DB)
            report_key = f"paper_report:{file_hash}"
            await redis.set(report_key, json.dumps(report_data))
        finally:
            await redis.aclose()

    @staticmethod
    async def get_paper_report(file_hash: str):
        """Load the stored report for ``file_hash``.

        Returns:
            The decoded JSON value, or ``None`` when the key is missing or
            holds an empty value.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_REPORT_DB)
            report_key = f"paper_report:{file_hash}"
            report_data = await redis.get(report_key)
            return json.loads(report_data) if report_data else None
        finally:
            await redis.aclose()

    @staticmethod
    async def save_qa_history(file_hash: str, question: str, answer: str):
        """Append one question/answer pair to the history of ``file_hash``.

        The history is a JSON list stored under ``chat_history:<file_hash>``;
        each entry carries the question, the answer and a UTC ISO-8601
        timestamp.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_CHAT_DB)
            chat_history_key = f"chat_history:{file_hash}"

            # NOTE(review): this is a read-modify-write sequence (GET, then
            # SET); two concurrent writers for the same paper could overwrite
            # each other's entry. Left as-is to keep the storage format.
            existing_history = await redis.get(chat_history_key)
            history = json.loads(existing_history) if existing_history else []

            history.append({
                "question": question,
                "answer": answer,
                "timestamp": datetime.now(timezone.utc).isoformat()
            })

            await redis.set(chat_history_key, json.dumps(history))
        finally:
            await redis.aclose()

    @staticmethod
    async def get_qa_history(file_hash: str):
        """Return the stored Q&A history for ``file_hash``.

        Returns:
            The decoded JSON list, or an empty list when nothing is stored.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_CHAT_DB)
            chat_history_key = f"chat_history:{file_hash}"
            history_data = await redis.get(chat_history_key)
            return json.loads(history_data) if history_data else []
        finally:
            await redis.aclose()

    @staticmethod
    async def save_task_status(task_id: str, status_data: dict):
        """Write ``status_data`` as fields of the hash ``task:<task_id>``."""
        redis = await get_redis()
        try:
            await redis.select(REDIS_TASK_DB)
            await redis.hset(f"task:{task_id}", mapping=status_data)
        finally:
            await redis.aclose()

    @staticmethod
    async def get_task_status(task_id: str):
        """Read all fields of the hash ``task:<task_id>``.

        Returns:
            The field mapping, or ``None`` when the hash is empty or missing.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_TASK_DB)
            task_data = await redis.hgetall(f"task:{task_id}")
            return task_data if task_data else None
        finally:
            await redis.aclose()

    @staticmethod
    async def delete_paper(paper_id: str):
        """Delete a paper and its Redis-side data.

        The paper is looked up by ``file_hash == paper_id``. When found, the
        MongoDB document is deleted first, then the stored report and the
        chat history are removed from Redis.

        Returns:
            ``None`` when no such paper exists, otherwise
            ``{"message": "paper successfully deleted"}``.

        Raises:
            Exception: Wrapping any error raised while deleting; the original
                error is chained as ``__cause__``.
        """
        database = await get_database()

        try:
            paper = await database.papers.find_one({"file_hash": paper_id})
            if not paper:
                return None

            file_hash = paper.get("file_hash")
            await database.papers.delete_one({"file_hash": paper_id})

            # The Redis client is created only once it is actually needed, so
            # neither the "not found" return above nor a MongoDB error can
            # leave an unclosed client behind.
            redis = await get_redis()
            try:
                await redis.select(REDIS_REPORT_DB)
                report_key = f"paper_report:{file_hash}"
                await redis.delete(report_key)

                await redis.select(REDIS_CHAT_DB)
                chat_history_key = f"chat_history:{paper_id}"
                await redis.delete(chat_history_key)
            finally:
                await redis.aclose()

            return {"message": "paper successfully deleted"}
        except Exception as e:
            raise Exception(f"Failed to delete paper: {str(e)}") from e

    @staticmethod
    async def get_papers():
        """List all papers in the ``papers`` collection.

        Returns:
            A list of dicts with ``_id`` (as a string), ``paper_title``,
            ``upload_time`` and ``file_hash`` (``None`` when the document has
            no such field).
        """
        database = await get_database()
        return [
            {
                "_id": str(ref["_id"]),
                "paper_title": ref["paper_title"],
                "upload_time": ref["upload_time"],
                "file_hash": ref.get("file_hash")
            }
            async for ref in database.papers.find()
        ]

    @staticmethod
    async def get_paper(file_hash: str):
        """Return the paper document whose ``file_hash`` matches, as found by ``find_one``."""
        database = await get_database()
        return await database.papers.find_one({"file_hash": file_hash})

    @staticmethod
    async def check_existing_report(file_hash: str):
        """Tell whether a finished report exists for ``file_hash``.

        Returns:
            ``True`` only when a report is stored and its ``status`` field is
            ``"completed"``; ``False`` otherwise.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_REPORT_DB)
            report_key = f"paper_report:{file_hash}"
            existing_report = await redis.get(report_key)
            if existing_report:
                report_data = json.loads(existing_report)
                if report_data.get("status") == "completed":
                    return True
            return False
        finally:
            await redis.aclose()

    @staticmethod
    async def save_analysis_error(file_hash: str, error_message: str):
        """Record a failed analysis for ``file_hash``.

        Overwrites the ``paper_report:<file_hash>`` key with a JSON status
        object whose ``status`` is ``"failed"`` and whose ``message`` is the
        stringified ``error_message``.
        """
        redis = await get_redis()
        try:
            await redis.select(REDIS_REPORT_DB)
            report_key = f"paper_report:{file_hash}"
            error_status = {
                "status": "failed",
                "message": str(error_message)
            }
            await redis.set(report_key, json.dumps(error_status))
        finally:
            await redis.aclose()
|
||||||
@@ -1,59 +1,23 @@
|
|||||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from typing import List, Optional
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import asyncio
|
import asyncio
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from redis import asyncio as aioredis
|
|
||||||
from typing import Any
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import PyPDF2
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from openai import OpenAI
|
from contextlib import asynccontextmanager
|
||||||
import hashlib
|
|
||||||
|
|
||||||
# Database Configuration
|
from database import (
|
||||||
MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper"
|
connect_to_mongo, close_mongo_connection,
|
||||||
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
|
DatabaseOperations as db_ops
|
||||||
|
)
|
||||||
|
from api import call_deepseek, call_deepseek_sync
|
||||||
|
|
||||||
|
# 在全局范围创建线程池
|
||||||
# Database connection
|
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
|
||||||
class Database:
|
|
||||||
"""数据库连接管理类"""
|
|
||||||
client: AsyncIOMotorClient = None
|
|
||||||
|
|
||||||
db = Database()
|
|
||||||
|
|
||||||
async def get_database() -> AsyncIOMotorClient:
|
|
||||||
return db.client["paper"]
|
|
||||||
|
|
||||||
async def connect_to_mongo():
|
|
||||||
try:
|
|
||||||
db.client = AsyncIOMotorClient(MONGODB_URL)
|
|
||||||
# 验证连接
|
|
||||||
await db.client.admin.command('ping')
|
|
||||||
print("Successfully connected to MongoDB")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Could not connect to MongoDB: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def close_mongo_connection():
|
|
||||||
db.client.close()
|
|
||||||
|
|
||||||
# Redis setup
|
|
||||||
async def get_redis():
|
|
||||||
redis = aioredis.from_url(
|
|
||||||
REDIS_URL,
|
|
||||||
encoding="utf-8",
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
return redis
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -76,7 +40,6 @@ app.add_middleware(
|
|||||||
expose_headers=["*"]
|
expose_headers=["*"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 在其他 Pydantic 模型后添加
|
# 在其他 Pydantic 模型后添加
|
||||||
class paperModel(BaseModel):
|
class paperModel(BaseModel):
|
||||||
"""文献引用模型"""
|
"""文献引用模型"""
|
||||||
@@ -89,101 +52,36 @@ class paperModel(BaseModel):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
json_encoders = {ObjectId: str}
|
json_encoders = {ObjectId: str}
|
||||||
|
|
||||||
|
|
||||||
# 删除文献
|
# 删除文献
|
||||||
@app.delete("/paper/delete/{paper_id}")
|
@app.delete("/paper/delete/{paper_id}")
|
||||||
async def delete_paper(
|
async def delete_paper(paper_id: str):
|
||||||
paper_id: str
|
|
||||||
):
|
|
||||||
"""删除文献"""
|
"""删除文献"""
|
||||||
db = await get_database()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取文献信息
|
result = await db_ops.delete_paper(paper_id)
|
||||||
paper = await db.papers.find_one({"file_hash": paper_id})
|
if not result:
|
||||||
if not paper:
|
|
||||||
raise HTTPException(status_code=404, detail="paper not found")
|
raise HTTPException(status_code=404, detail="paper not found")
|
||||||
|
return result
|
||||||
# 从文件存储服务删除文件
|
|
||||||
file_hash = paper.get("file_hash")
|
|
||||||
# 删除数据库记录
|
|
||||||
await db.papers.delete_one({"file_hash": paper_id})
|
|
||||||
|
|
||||||
# 删除Redis中的分析报告(如果存在)
|
|
||||||
try:
|
|
||||||
redis = await get_redis()
|
|
||||||
await redis.select(190)
|
|
||||||
report_key = f"paper_report:{file_hash}"
|
|
||||||
await redis.delete(report_key)
|
|
||||||
|
|
||||||
# 删除聊天历史(如果存在)
|
|
||||||
await redis.select(191)
|
|
||||||
chat_history_key = f"chat_history:{paper_id}"
|
|
||||||
await redis.delete(chat_history_key)
|
|
||||||
except Exception as redis_error:
|
|
||||||
print(f"Warning: Failed to delete Redis keys: {str(redis_error)}")
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
return {"message": "paper successfully deleted"}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error deleting paper: {str(e)}")
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to delete paper: {str(e)}")
|
|
||||||
|
|
||||||
@app.get("/paper/papers")
|
@app.get("/paper/papers")
|
||||||
async def get_papers():
|
async def get_papers():
|
||||||
"""获取文献列表"""
|
"""获取文献列表"""
|
||||||
db = await get_database()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
papers = []
|
return await db_ops.get_papers()
|
||||||
async for ref in db.papers.find():
|
|
||||||
papers.append({
|
|
||||||
"_id": str(ref["_id"]),
|
|
||||||
"paper_link": ref["paper_link"],
|
|
||||||
"paper_title": ref["paper_title"],
|
|
||||||
"upload_time": ref["upload_time"],
|
|
||||||
"file_hash": ref.get("file_hash")
|
|
||||||
})
|
|
||||||
return papers
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting papers: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.get("/paper/report/{file_hash}")
|
@app.get("/paper/report/{file_hash}")
|
||||||
async def get_report(
|
async def get_report(file_hash: str):
|
||||||
file_hash: str
|
"""获取文献报告"""
|
||||||
):
|
|
||||||
"""从 Redis db190 直接通过文件哈希值读取已保存的文献报告"""
|
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 选择 db190
|
report = await db_ops.get_paper_report(file_hash)
|
||||||
await redis.select(190)
|
if not report:
|
||||||
report_key = f"paper_report:{file_hash}"
|
|
||||||
|
|
||||||
# 获取已保存的报告
|
|
||||||
existing_report = await redis.get(report_key)
|
|
||||||
|
|
||||||
if not existing_report:
|
|
||||||
raise HTTPException(status_code=404, detail="No saved paper report found")
|
raise HTTPException(status_code=404, detail="No saved paper report found")
|
||||||
|
return report
|
||||||
# 返回报告
|
|
||||||
return json.loads(existing_report)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting paper report: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
class FileUpload(BaseModel):
|
class FileUpload(BaseModel):
|
||||||
filename: str
|
filename: str
|
||||||
@@ -195,9 +93,6 @@ class BatchUploadRequest(BaseModel):
|
|||||||
@app.post("/paper/upload")
|
@app.post("/paper/upload")
|
||||||
async def batch_upload(request: BatchUploadRequest):
|
async def batch_upload(request: BatchUploadRequest):
|
||||||
"""批量上传项目相关文献"""
|
"""批量上传项目相关文献"""
|
||||||
db = await get_database()
|
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
uploaded_papers = []
|
uploaded_papers = []
|
||||||
papers_to_analyze = []
|
papers_to_analyze = []
|
||||||
@@ -206,13 +101,12 @@ async def batch_upload(request: BatchUploadRequest):
|
|||||||
try:
|
try:
|
||||||
# 1. 创建新记录
|
# 1. 创建新记录
|
||||||
paper = {
|
paper = {
|
||||||
"paper_link": f"https://files.aiot.ml/pdf/{file.hash}",
|
|
||||||
"paper_title": file.filename,
|
"paper_title": file.filename,
|
||||||
"upload_time": datetime.now(timezone.utc),
|
"upload_time": datetime.now(timezone.utc),
|
||||||
"file_hash": file.hash
|
"file_hash": file.hash
|
||||||
}
|
}
|
||||||
|
|
||||||
result = await db.papers.insert_one(paper)
|
result = await db_ops.insert_paper(paper)
|
||||||
paper_info = {
|
paper_info = {
|
||||||
"paper_id": str(result.inserted_id),
|
"paper_id": str(result.inserted_id),
|
||||||
"file_hash": file.hash,
|
"file_hash": file.hash,
|
||||||
@@ -221,13 +115,7 @@ async def batch_upload(request: BatchUploadRequest):
|
|||||||
uploaded_papers.append(paper_info)
|
uploaded_papers.append(paper_info)
|
||||||
|
|
||||||
# 2. 设置分析状态
|
# 2. 设置分析状态
|
||||||
await redis.select(190)
|
await db_ops.set_paper_processing_status(file.hash)
|
||||||
report_key = f"paper_report:{file.hash}"
|
|
||||||
initial_status = {
|
|
||||||
"status": "processing",
|
|
||||||
"message": "Analysis in progress"
|
|
||||||
}
|
|
||||||
await redis.set(report_key, json.dumps(initial_status))
|
|
||||||
papers_to_analyze.append(paper_info)
|
papers_to_analyze.append(paper_info)
|
||||||
|
|
||||||
except Exception as file_error:
|
except Exception as file_error:
|
||||||
@@ -254,56 +142,49 @@ async def batch_upload(request: BatchUploadRequest):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"批量上传错误: {str(e)}")
|
print(f"批量上传错误: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
|
async def get_pdf_content(file_hash: str) -> Optional[str]:
    """Fetch the content of a PDF from the file storage service.

    Args:
        file_hash: Hash identifying the file; appended to the service URL.

    Returns:
        The ``content`` field of the JSON response. A list value is joined
        with newlines into a single string. ``None`` is returned when the
        response status is not 200 or when ``content`` is missing or empty.
    """
    pdf_url = f'https://files.aiot.ml/pdf/{file_hash}'
    async with aiohttp.ClientSession() as session:
        async with session.get(pdf_url) as response:
            if response.status != 200:
                print(f"获取PDF内容失败: {response.status}")
                return None

            payload = await response.json()
            text = payload.get('content')
            if not text:
                return None

            # The service may return the content as a list of strings.
            return '\n'.join(text) if isinstance(text, list) else text
|
||||||
|
|
||||||
async def batch_analysis(papers: List[dict]):
|
async def batch_analysis(papers: List[dict]):
|
||||||
"""批量处理文献分析的后台任务"""
|
"""批量处理文献分析的后台任务"""
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
# 限制并发数量
|
# 限制并发数量
|
||||||
semaphore = asyncio.Semaphore(3)
|
semaphore = asyncio.Semaphore(3)
|
||||||
|
|
||||||
async def single_paper(ref: dict):
|
async def single_paper(ref: dict):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
try:
|
try:
|
||||||
paper_id = ref["paper_id"]
|
|
||||||
file_hash = ref["file_hash"]
|
file_hash = ref["file_hash"]
|
||||||
|
|
||||||
# 再次检查报告是否存在(以防在开始分析前已经被其他进程分析)
|
# 检查报告是否存在
|
||||||
await redis.select(190)
|
if await db_ops.check_existing_report(file_hash):
|
||||||
report_key = f"paper_report:{file_hash}"
|
print(f"Report already exists for file hash: {file_hash}, skipping analysis")
|
||||||
existing_report = await redis.get(report_key)
|
return
|
||||||
if existing_report:
|
|
||||||
report_data = json.loads(existing_report)
|
# 获取PDF内容
|
||||||
if report_data.get("status") == "completed":
|
content = await get_pdf_content(file_hash)
|
||||||
print(f"Report already exists for file hash: {file_hash}, skipping analysis")
|
if not content:
|
||||||
return
|
await db_ops.save_analysis_error(file_hash, "Failed to get PDF content")
|
||||||
# 从文件存储服务获取PDF内容
|
return
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
|
|
||||||
if response.status != 200:
|
|
||||||
print(f"[Redis] 获取PDF内容失败: {response.status}")
|
|
||||||
raise Exception(f"Failed to get PDF content for file hash: {file_hash}")
|
|
||||||
|
|
||||||
pdf_content = await response.json()
|
|
||||||
if not pdf_content.get('content'):
|
|
||||||
raise Exception("No PDF content returned")
|
|
||||||
|
|
||||||
print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
|
print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
|
||||||
print(f"文献ID: {paper_id}")
|
|
||||||
print(f"文件哈希: {file_hash}")
|
print(f"文件哈希: {file_hash}")
|
||||||
|
|
||||||
# 使用获取到的PDF内容继续处理
|
|
||||||
content = pdf_content['content']
|
|
||||||
# 如果content是列表,将其合并为单个字符串
|
|
||||||
if isinstance(content, list):
|
|
||||||
content = '\n'.join(content)
|
|
||||||
|
|
||||||
# 打印字符数
|
# 打印字符数
|
||||||
content_length = len(content)
|
content_length = len(content)
|
||||||
print(f"\n=== 步骤2: 内容长度检查 ===")
|
print(f"\n=== 步骤2: 内容长度检查 ===")
|
||||||
@@ -315,19 +196,22 @@ async def batch_analysis(papers: List[dict]):
|
|||||||
print(f"文档长度在处理范围内 ({content_length} <= 200000)")
|
print(f"文档长度在处理范围内 ({content_length} <= 200000)")
|
||||||
document_analysis = await analyze_paper(content[:180000])
|
document_analysis = await analyze_paper(content[:180000])
|
||||||
if not document_analysis:
|
if not document_analysis:
|
||||||
raise Exception("Failed to analyze document")
|
await db_ops.save_analysis_error(file_hash, "Failed to analyze document")
|
||||||
|
return
|
||||||
print("文档分析完成")
|
print("文档分析完成")
|
||||||
else:
|
else:
|
||||||
print(f"\n=== 步骤3B: 使用分段分析方式 ===")
|
print(f"\n=== 步骤3B: 使用分段分析方式 ===")
|
||||||
print(f"文档超过200000字符 ({content_length} > 200000)")
|
print(f"文档超过200000字符 ({content_length} > 200000)")
|
||||||
analysis_results = await analyze_long_file(content)
|
analysis_results = await analyze_long_file(content)
|
||||||
if not analysis_results:
|
if not analysis_results:
|
||||||
raise Exception("Failed to analyze document in segments")
|
await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments")
|
||||||
|
return
|
||||||
print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
|
print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
|
||||||
|
|
||||||
document_analysis = await merge_results(analysis_results)
|
document_analysis = await merge_results(analysis_results)
|
||||||
if not document_analysis:
|
if not document_analysis:
|
||||||
raise Exception("Failed to merge analysis results")
|
await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results")
|
||||||
|
return
|
||||||
|
|
||||||
# 等待一小段时间避免API限制
|
# 等待一小段时间避免API限制
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
@@ -335,7 +219,8 @@ async def batch_analysis(papers: List[dict]):
|
|||||||
# 异步分析文献价值
|
# 异步分析文献价值
|
||||||
value_evaluation = await paper_value(document_analysis)
|
value_evaluation = await paper_value(document_analysis)
|
||||||
if not value_evaluation:
|
if not value_evaluation:
|
||||||
raise Exception("Failed to evaluate value")
|
await db_ops.save_analysis_error(file_hash, "Failed to evaluate value")
|
||||||
|
return
|
||||||
print("文献价值分析完成")
|
print("文献价值分析完成")
|
||||||
|
|
||||||
# 合并结果
|
# 合并结果
|
||||||
@@ -347,98 +232,26 @@ async def batch_analysis(papers: List[dict]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 保存结果
|
# 保存结果
|
||||||
await redis.select(190)
|
await db_ops.save_paper_report(file_hash, analysis_result)
|
||||||
await redis.set(report_key, json.dumps(analysis_result))
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
print(f"分析文献时出错: {str(e)}")
|
||||||
await redis.select(190)
|
await db_ops.save_analysis_error(file_hash, str(e))
|
||||||
report_key = f"paper_report:{file_hash}"
|
|
||||||
error_status = {
|
|
||||||
"status": "failed",
|
|
||||||
"message": str(e)
|
|
||||||
}
|
|
||||||
await redis.set(report_key, json.dumps(error_status))
|
|
||||||
except Exception as redis_error:
|
|
||||||
print(f"Error updating Redis status: {redis_error}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 并发处理所有引用
|
# 并发处理所有引用
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(single_paper(ref) for ref in papers)
|
*(single_paper(ref) for ref in papers)
|
||||||
)
|
)
|
||||||
finally:
|
except Exception as e:
|
||||||
try:
|
print(f"批量分析任务出错: {str(e)}")
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
# 在全局范围创建线程池
|
|
||||||
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
|
|
||||||
# DeepSeek API Configuration
|
|
||||||
client = OpenAI(
|
|
||||||
base_url="https://api.deepseek.com/v1",
|
|
||||||
api_key="sk-3027fb3c810b4e17985fa397d41250b9"
|
|
||||||
)
|
|
||||||
# 创建异步HTTP客户端会话
|
|
||||||
async def get_aiohttp_session():
|
|
||||||
return aiohttp.ClientSession(
|
|
||||||
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
|
|
||||||
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def read_pdf(file_path: str) -> str:
|
|
||||||
"""在线程池中异步读取PDF"""
|
|
||||||
def read_pdf():
|
|
||||||
try:
|
|
||||||
pdf_content = ""
|
|
||||||
with open(file_path, 'rb') as file:
|
|
||||||
pdf_reader = PyPDF2.PdfReader(file)
|
|
||||||
for page in pdf_reader.pages:
|
|
||||||
content = page.extract_text()
|
|
||||||
pdf_content += content
|
|
||||||
|
|
||||||
# 根据内容长度返回不同的结果
|
|
||||||
content_length = len(pdf_content)
|
|
||||||
print(f"\nPDF原始内容字符数: {content_length}")
|
|
||||||
|
|
||||||
if content_length <= 180000:
|
|
||||||
print("文档长度 ≤ 180000,返回完整内容")
|
|
||||||
return pdf_content
|
|
||||||
elif content_length <= 200000:
|
|
||||||
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
|
|
||||||
return pdf_content[:180000]
|
|
||||||
else:
|
|
||||||
print("文档长度 > 200000,返回完整内容供分段处理")
|
|
||||||
return pdf_content # 返回完整内容,由调用者处理分段
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading PDF: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
|
|
||||||
|
|
||||||
async def call_deepseek(messages: list) -> dict:
|
|
||||||
"""异步调用DeepSeek API"""
|
|
||||||
async with await get_aiohttp_session() as session:
|
|
||||||
async with session.post("/chat/completions", json={
|
|
||||||
"model": "deepseek-chat",
|
|
||||||
"messages": messages,
|
|
||||||
"response_format": {"type": "json_object"}
|
|
||||||
}) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
data = await response.json()
|
|
||||||
return json.loads(data["choices"][0]["message"]["content"])
|
|
||||||
else:
|
|
||||||
raise Exception(f"API调用失败: {await response.text()}")
|
|
||||||
|
|
||||||
async def analyze_long_file(content: str) -> List[dict]:
|
async def analyze_long_file(content: str) -> List[dict]:
|
||||||
"""分段分析长文档"""
|
"""分段分析长文档"""
|
||||||
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens)
|
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens)
|
||||||
segments = []
|
segments = []
|
||||||
content_length = len(content)
|
content_length = len(content)
|
||||||
segment_size = 60000 # 减小段落大小
|
segment_size = 180000 # 减小段落大小
|
||||||
|
|
||||||
print(f"\n开始分段处理,总字符数: {content_length}")
|
print(f"\n开始分段处理,总字符数: {content_length}")
|
||||||
print(f"每段大小: {segment_size} 字符")
|
print(f"每段大小: {segment_size} 字符")
|
||||||
@@ -598,6 +411,7 @@ async def merge_results(results: List[dict]) -> dict:
|
|||||||
print("分析结果合并失败")
|
print("分析结果合并失败")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def analyze_paper(content: str):
|
async def analyze_paper(content: str):
|
||||||
"""分析文献的基本信息和内容"""
|
"""分析文献的基本信息和内容"""
|
||||||
system_prompt = """
|
system_prompt = """
|
||||||
@@ -698,20 +512,9 @@ class TaskStatus:
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
@app.post("/paper/{file_hash}/qa")
|
@app.post("/paper/{file_hash}/qa")
|
||||||
async def ask_reference_question(
|
async def ask_reference_question(file_hash: str, question: QuestionModel):
|
||||||
file_hash: str,
|
|
||||||
question: QuestionModel
|
|
||||||
):
|
|
||||||
"""Ask questions about the paper (async)"""
|
"""Ask questions about the paper (async)"""
|
||||||
db = await get_database()
|
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify paper exists
|
|
||||||
paper = await db.papers.find_one({"file_hash": file_hash})
|
|
||||||
if not paper:
|
|
||||||
raise HTTPException(status_code=404, detail="Paper not found")
|
|
||||||
|
|
||||||
# Generate task ID
|
# Generate task ID
|
||||||
task_id = str(ObjectId())
|
task_id = str(ObjectId())
|
||||||
|
|
||||||
@@ -719,8 +522,7 @@ async def ask_reference_question(
|
|||||||
asyncio.create_task(paper_question(
|
asyncio.create_task(paper_question(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
file_hash=file_hash,
|
file_hash=file_hash,
|
||||||
question=question.question,
|
question=question.question
|
||||||
paper=paper
|
|
||||||
))
|
))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -730,56 +532,36 @@ async def ask_reference_question(
|
|||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in ask_reference_question: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
async def paper_question(task_id: str, file_hash: str, question: str, paper: dict):
|
async def paper_question(task_id: str, file_hash: str, question: str):
|
||||||
"""Process paper Q&A background task"""
|
"""Process paper Q&A background task"""
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Get paper info
|
||||||
|
paper = await db_ops.get_paper(file_hash)
|
||||||
|
if not paper:
|
||||||
|
raise Exception("Paper not found")
|
||||||
|
|
||||||
# Update task status to processing
|
# Update task status to processing
|
||||||
await redis.select(192) # Use db192 to store task status
|
await db_ops.save_task_status(task_id, {
|
||||||
await redis.hset(
|
"status": TaskStatus.PROCESSING,
|
||||||
f"task:{task_id}",
|
"file_hash": file_hash,
|
||||||
mapping={
|
"question": question
|
||||||
"status": TaskStatus.PROCESSING,
|
})
|
||||||
"file_hash": file_hash,
|
|
||||||
"question": question
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get analysis report from Redis
|
|
||||||
await redis.select(190)
|
|
||||||
report_key = f"paper_report:{file_hash}"
|
|
||||||
report_data = await redis.get(report_key)
|
|
||||||
|
|
||||||
|
# Get analysis report
|
||||||
|
report_data = await db_ops.get_paper_report(file_hash)
|
||||||
if not report_data:
|
if not report_data:
|
||||||
raise Exception("Paper analysis report does not exist, please analyze first")
|
raise Exception("Paper analysis report does not exist, please analyze first")
|
||||||
|
|
||||||
# Get PDF content from file storage service
|
# Get PDF content
|
||||||
async with aiohttp.ClientSession() as session:
|
content = await get_pdf_content(file_hash)
|
||||||
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
|
if not content:
|
||||||
if response.status != 200:
|
raise Exception("Failed to get PDF content")
|
||||||
raise Exception(f"Failed to get PDF content: HTTP {response.status}")
|
|
||||||
|
# Limit content length
|
||||||
pdf_content = await response.json()
|
MAX_CHARS = 180000
|
||||||
if not pdf_content.get('content'):
|
content = content[:MAX_CHARS]
|
||||||
raise Exception("No PDF content returned")
|
|
||||||
|
|
||||||
content = pdf_content['content']
|
|
||||||
# If content is a list, join it into a single string
|
|
||||||
if isinstance(content, list):
|
|
||||||
content = '\n'.join(content)
|
|
||||||
|
|
||||||
# Limit content length
|
|
||||||
MAX_CHARS = 180000
|
|
||||||
content = content[:MAX_CHARS]
|
|
||||||
|
|
||||||
# Build system prompt and user prompt
|
# Build system prompt and user prompt
|
||||||
system_prompt = """
|
system_prompt = """
|
||||||
@@ -791,7 +573,7 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
|
|||||||
# Build context
|
# Build context
|
||||||
context = {
|
context = {
|
||||||
"Paper Content": content,
|
"Paper Content": content,
|
||||||
"Analysis Report": json.loads(report_data)
|
"Analysis Report": report_data
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -800,94 +582,39 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Call DeepSeek API for answer
|
# Call DeepSeek API for answer
|
||||||
response = client.chat.completions.create(
|
answer = await call_deepseek_sync(messages)
|
||||||
model="deepseek-chat",
|
|
||||||
messages=messages
|
|
||||||
)
|
|
||||||
answer = response.choices[0].message.content
|
|
||||||
|
|
||||||
# Save conversation history to Redis db191
|
# Save conversation history
|
||||||
await redis.select(191)
|
await db_ops.save_qa_history(file_hash, question, answer)
|
||||||
chat_history_key = f"chat_history:{file_hash}"
|
|
||||||
|
|
||||||
# Get existing history
|
|
||||||
existing_history = await redis.get(chat_history_key)
|
|
||||||
history = json.loads(existing_history) if existing_history else []
|
|
||||||
|
|
||||||
# Add new conversation
|
|
||||||
history.append({
|
|
||||||
"question": question,
|
|
||||||
"answer": answer,
|
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save updated history
|
|
||||||
await redis.set(chat_history_key, json.dumps(history))
|
|
||||||
|
|
||||||
# Update task status to completed
|
# Update task status to completed
|
||||||
await redis.select(192)
|
await db_ops.save_task_status(task_id, {
|
||||||
await redis.hset(
|
"status": TaskStatus.COMPLETED,
|
||||||
f"task:{task_id}",
|
"answer": answer,
|
||||||
mapping={
|
"paper_title": paper.get("paper_title")
|
||||||
"status": TaskStatus.COMPLETED,
|
})
|
||||||
"answer": answer,
|
|
||||||
"paper_title": paper.get("paper_title")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing question: {e}")
|
print(f"Error processing question: {e}")
|
||||||
# Update task status to failed
|
# Update task status to failed
|
||||||
await redis.select(192)
|
await db_ops.save_task_status(task_id, {
|
||||||
await redis.hset(
|
"status": TaskStatus.FAILED,
|
||||||
f"task:{task_id}",
|
"error": str(e)
|
||||||
mapping={
|
})
|
||||||
"status": TaskStatus.FAILED,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
@app.get("/paper/{file_hash}/qa/history")
|
@app.get("/paper/{file_hash}/qa/history")
|
||||||
async def get_paper_qa_history(
|
async def get_paper_qa_history(file_hash: str):
|
||||||
file_hash: str
|
|
||||||
):
|
|
||||||
"""Get paper Q&A history"""
|
"""Get paper Q&A history"""
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get conversation history from Redis db191
|
return await db_ops.get_qa_history(file_hash)
|
||||||
await redis.select(191)
|
|
||||||
chat_history_key = f"chat_history:{file_hash}"
|
|
||||||
|
|
||||||
history_data = await redis.get(chat_history_key)
|
|
||||||
if not history_data:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return json.loads(history_data)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting QA history: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
@app.get("/paper/task/{task_id}")
|
@app.get("/paper/task/{task_id}")
|
||||||
async def get_task_status(task_id: str):
|
async def get_task_status(task_id: str):
|
||||||
"""Get task status and result"""
|
"""Get task status and result"""
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await redis.select(192)
|
task_data = await db_ops.get_task_status(task_id)
|
||||||
task_data = await redis.hgetall(f"task:{task_id}")
|
|
||||||
|
|
||||||
if not task_data:
|
if not task_data:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
@@ -896,59 +623,30 @@ async def get_task_status(task_id: str):
|
|||||||
"status": task_data.get("status", TaskStatus.PENDING)
|
"status": task_data.get("status", TaskStatus.PENDING)
|
||||||
}
|
}
|
||||||
|
|
||||||
# If task completed, add results
|
|
||||||
if task_data.get("status") == TaskStatus.COMPLETED:
|
if task_data.get("status") == TaskStatus.COMPLETED:
|
||||||
response.update({
|
response.update({
|
||||||
"answer": task_data.get("answer"),
|
"answer": task_data.get("answer"),
|
||||||
"paper_title": task_data.get("paper_title")
|
"paper_title": task_data.get("paper_title")
|
||||||
})
|
})
|
||||||
# If task failed, add error information
|
|
||||||
elif task_data.get("status") == TaskStatus.FAILED:
|
elif task_data.get("status") == TaskStatus.FAILED:
|
||||||
response.update({
|
response.update({
|
||||||
"error": task_data.get("error")
|
"error": task_data.get("error")
|
||||||
})
|
})
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting task status: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
# 添加新的路由,通过哈希值获取报告
|
|
||||||
@app.get("/paper/check/{file_hash}")
|
@app.get("/paper/check/{file_hash}")
|
||||||
async def check_report(
|
async def check_report(file_hash: str):
|
||||||
file_hash: str
|
"""检查文献报告状态"""
|
||||||
):
|
|
||||||
"""直接通过文件哈希值从 Redis db190 读取已保存的文献报告"""
|
|
||||||
redis = await get_redis()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 选择 db190
|
report = await db_ops.get_paper_report(file_hash)
|
||||||
await redis.select(190)
|
if not report:
|
||||||
report_key = f"paper_report:{file_hash}"
|
|
||||||
|
|
||||||
# 获取已保存的报告
|
|
||||||
existing_report = await redis.get(report_key)
|
|
||||||
|
|
||||||
if not existing_report:
|
|
||||||
return {"status": "not_found"}
|
return {"status": "not_found"}
|
||||||
|
return report
|
||||||
# 返回报告
|
|
||||||
return json.loads(existing_report)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error getting paper report by hash: {str(e)}")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
await redis.aclose()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error closing Redis connection: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
Reference in New Issue
Block a user