update
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,41 @@
|
||||
from openai import OpenAI
|
||||
import aiohttp
|
||||
import json
|
||||
|
||||
# DeepSeek API configuration.
#
# SECURITY: the API key should be supplied through the DEEPSEEK_API_KEY
# environment variable. The literal below is kept only as a
# backward-compatible fallback; it is stored in the repository and should be
# rotated and then removed from source.
import os

DEEPSEEK_API_KEY = os.environ.get(
    "DEEPSEEK_API_KEY",
    "sk-3027fb3c810b4e17985fa397d41250b9",
)
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
|
||||
|
||||
# Shared synchronous OpenAI-compatible client pointed at the DeepSeek endpoint.
client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
|
||||
|
||||
async def get_aiohttp_session():
    """Create an aiohttp session preconfigured for the DeepSeek API.

    The session is built with ``DEEPSEEK_BASE_URL`` plus a trailing slash as
    its base URL and an ``Authorization: Bearer <key>`` default header.
    The caller receives a fresh session on every call and is responsible for
    closing it.
    """
    auth_headers = {"Authorization": f"Bearer {DEEPSEEK_API_KEY}"}
    api_root = f"{DEEPSEEK_BASE_URL}/"
    return aiohttp.ClientSession(base_url=api_root, headers=auth_headers)
|
||||
|
||||
async def call_deepseek(messages: list) -> dict:
    """Call the DeepSeek chat-completions endpoint asynchronously.

    Posts ``messages`` to the ``deepseek-chat`` model, requesting a
    JSON-object response, and decodes the first choice's message content
    as JSON.

    Args:
        messages: Chat messages placed in the request payload as-is.

    Returns:
        The JSON-decoded content of the first choice's message.

    Raises:
        Exception: If the HTTP status is not 200; the message carries the
            response body text.
    """
    payload = {
        "model": "deepseek-chat",
        "messages": messages,
        "response_format": {"type": "json_object"},
    }
    async with await get_aiohttp_session() as session:
        # NOTE(review): the path starts with "/"; depending on the installed
        # aiohttp version this may replace the path part of the session's
        # base URL when the two are joined -- confirm the effective URL.
        async with session.post("/chat/completions", json=payload) as response:
            if response.status != 200:
                raise Exception(f"API调用失败: {await response.text()}")
            data = await response.json()
            return json.loads(data["choices"][0]["message"]["content"])
|
||||
|
||||
async def call_deepseek_sync(messages: list) -> str:
    """Call DeepSeek through the synchronous client without blocking the loop.

    The module-level ``client`` performs a blocking HTTP request. Calling it
    directly inside a coroutine would stall every other task on the event
    loop for the duration of the request, so the call is run in the loop's
    default executor and awaited.

    Args:
        messages: Chat messages passed to ``client.chat.completions.create``.

    Returns:
        The message content of the first choice in the response.
    """
    # Local import: asyncio is not among this module's top-level imports.
    import asyncio

    def _request():
        return client.chat.completions.create(
            model="deepseek-chat",
            messages=messages,
        )

    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, _request)
    return response.choices[0].message.content
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from redis import asyncio as aioredis
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# Database configuration.
#
# SECURITY: connection strings embed credentials and should be supplied
# through the MONGODB_URL / REDIS_URL environment variables. The literals
# below are kept only as backward-compatible fallbacks; they are stored in
# the repository and should be rotated and then removed from source.
import os

MONGODB_URL = os.environ.get(
    "MONGODB_URL",
    "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper",
)
REDIS_URL = os.environ.get(
    "REDIS_URL",
    "redis://:Obscura@2024@222.186.10.253:6379",
)

# Redis logical database indexes
REDIS_REPORT_DB = 190  # paper analysis reports
REDIS_CHAT_DB = 191  # chat history
REDIS_TASK_DB = 192  # task status
|
||||
|
||||
class Database:
    """Holder for the shared MongoDB client."""

    # None until connect_to_mongo() assigns a client.
    client: Optional[AsyncIOMotorClient] = None


# Module-level instance read and updated by the connection helpers.
db = Database()
|
||||
|
||||
async def get_database() -> AsyncIOMotorClient:
    """Return the handle for the ``paper`` database on the shared client.

    NOTE(review): the declared return annotation is the client type, but the
    value returned is the result of indexing the client by database name --
    presumably a Motor database object; confirm and tighten the annotation.

    Raises:
        RuntimeError: If no client has been set on ``db`` yet.
    """
    if db.client is None:
        # Fail with an explicit message instead of "'NoneType' object is not
        # subscriptable" when the connection has not been established.
        raise RuntimeError(
            "MongoDB client is not initialised; call connect_to_mongo() first"
        )
    return db.client["paper"]
|
||||
|
||||
async def connect_to_mongo():
    """Connect to MongoDB and verify the connection with a ping.

    The new client is published on ``db.client`` only after the ping
    succeeds. If creating the client or pinging fails, the error is printed,
    the half-open client (if any) is closed, and the exception is re-raised;
    ``db.client`` is left untouched in that case.

    Raises:
        Exception: Whatever client construction or the ping raised.
    """
    mongo_client = None
    try:
        mongo_client = AsyncIOMotorClient(MONGODB_URL)
        await mongo_client.admin.command('ping')
    except Exception as e:
        print(f"Could not connect to MongoDB: {e}")
        if mongo_client is not None:
            # Do not leave an unverified client open or published.
            mongo_client.close()
        raise
    db.client = mongo_client
    print("Successfully connected to MongoDB")
|
||||
|
||||
async def close_mongo_connection():
    """Close the shared MongoDB client if one has been set.

    Does nothing when ``db.client`` is ``None`` (for example when shutdown
    runs after a connection was never established), instead of raising
    ``AttributeError``.
    """
    if db.client is None:
        return
    db.client.close()
|
||||
|
||||
async def get_redis():
    """Build a Redis client for ``REDIS_URL``.

    The client is configured with UTF-8 encoding and ``decode_responses``
    enabled. A new client is created on every call; the caller is
    responsible for closing it.
    """
    client_options = {"encoding": "utf-8", "decode_responses": True}
    return aioredis.from_url(REDIS_URL, **client_options)
|
||||
|
||||
class DatabaseOperations:
    """Static helpers for the MongoDB ``papers`` collection and Redis state.

    Redis layout used by this class:
      * ``REDIS_REPORT_DB``: ``paper_report:{file_hash}`` -> JSON report/status
      * ``REDIS_CHAT_DB``:   ``chat_history:{file_hash}`` -> JSON list of Q&A
      * ``REDIS_TASK_DB``:   ``task:{task_id}`` -> hash of task status fields

    Every Redis operation opens its own client through ``get_redis()`` and
    closes it when finished.
    """

    # ------------------------------------------------------------------
    # Private Redis helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _open_redis(db_index: int):
        """Return a new Redis client with logical database ``db_index`` selected.

        If the SELECT fails, the client is closed before the error propagates.
        The caller owns the returned client and must close it.
        """
        conn = await get_redis()
        try:
            await conn.select(db_index)
        except BaseException:
            await conn.aclose()
            raise
        return conn

    @staticmethod
    async def _load_json(db_index: int, key: str, default: Any = None):
        """Read ``key`` from ``db_index`` and JSON-decode it.

        Returns ``default`` when the key is missing or holds an empty value.
        """
        conn = await DatabaseOperations._open_redis(db_index)
        try:
            raw = await conn.get(key)
            return json.loads(raw) if raw else default
        finally:
            await conn.aclose()

    @staticmethod
    async def _store_json(db_index: int, key: str, value: Any) -> None:
        """JSON-encode ``value`` and write it to ``key`` in ``db_index``."""
        conn = await DatabaseOperations._open_redis(db_index)
        try:
            await conn.set(key, json.dumps(value))
        finally:
            await conn.aclose()

    # ------------------------------------------------------------------
    # MongoDB: papers collection
    # ------------------------------------------------------------------

    @staticmethod
    async def insert_paper(paper: dict):
        """Insert a paper document and return the ``insert_one`` result."""
        database = await get_database()
        return await database.papers.insert_one(paper)

    @staticmethod
    async def get_papers():
        """Return all papers as a list of plain dicts.

        Each entry carries ``_id`` (as ``str``), ``paper_title``,
        ``upload_time`` and ``file_hash`` (``None`` when absent).

        Raises:
            KeyError: If a stored document lacks ``paper_title`` or
                ``upload_time``.
        """
        database = await get_database()
        return [
            {
                "_id": str(doc["_id"]),
                "paper_title": doc["paper_title"],
                "upload_time": doc["upload_time"],
                "file_hash": doc.get("file_hash"),
            }
            async for doc in database.papers.find()
        ]

    @staticmethod
    async def get_paper(file_hash: str):
        """Return the paper document with this ``file_hash``, or ``None``."""
        database = await get_database()
        return await database.papers.find_one({"file_hash": file_hash})

    @staticmethod
    async def delete_paper(paper_id: str):
        """Delete a paper and its Redis report and chat history.

        ``paper_id`` is matched against the ``file_hash`` field.

        Returns:
            ``None`` if no such paper exists, otherwise
            ``{"message": "paper successfully deleted"}``.

        Raises:
            Exception: Wrapping (and chained to) any error raised while
                deleting.
        """
        database = await get_database()
        try:
            query = {"file_hash": paper_id}
            paper = await database.papers.find_one(query)
            if not paper:
                return None

            file_hash = paper.get("file_hash")
            await database.papers.delete_one(query)

            # The Redis client is created only once it is actually needed, so
            # the not-found and Mongo-error paths never leave one unclosed.
            conn = await get_redis()
            try:
                await conn.select(REDIS_REPORT_DB)
                await conn.delete(f"paper_report:{file_hash}")

                await conn.select(REDIS_CHAT_DB)
                await conn.delete(f"chat_history:{paper_id}")
            finally:
                await conn.aclose()

            return {"message": "paper successfully deleted"}
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"Failed to delete paper: {str(e)}") from e

    # ------------------------------------------------------------------
    # Redis: analysis reports
    # ------------------------------------------------------------------

    @staticmethod
    async def set_paper_processing_status(file_hash: str):
        """Mark the paper's report entry as ``processing``."""
        await DatabaseOperations._store_json(
            REDIS_REPORT_DB,
            f"paper_report:{file_hash}",
            {"status": "processing", "message": "Analysis in progress"},
        )

    @staticmethod
    async def save_paper_report(file_hash: str, report_data: dict):
        """Store ``report_data`` as the paper's analysis report."""
        await DatabaseOperations._store_json(
            REDIS_REPORT_DB,
            f"paper_report:{file_hash}",
            report_data,
        )

    @staticmethod
    async def get_paper_report(file_hash: str):
        """Return the decoded report for ``file_hash``, or ``None`` if absent."""
        return await DatabaseOperations._load_json(
            REDIS_REPORT_DB,
            f"paper_report:{file_hash}",
        )

    @staticmethod
    async def check_existing_report(file_hash: str):
        """Return ``True`` if a report with status ``completed`` is stored.

        A missing entry, or an entry that is not a JSON object, yields
        ``False``.
        """
        report = await DatabaseOperations._load_json(
            REDIS_REPORT_DB,
            f"paper_report:{file_hash}",
        )
        return isinstance(report, dict) and report.get("status") == "completed"

    @staticmethod
    async def save_analysis_error(file_hash: str, error_message: str):
        """Replace the paper's report entry with a ``failed`` status."""
        await DatabaseOperations._store_json(
            REDIS_REPORT_DB,
            f"paper_report:{file_hash}",
            {"status": "failed", "message": str(error_message)},
        )

    # ------------------------------------------------------------------
    # Redis: Q&A history
    # ------------------------------------------------------------------

    @staticmethod
    async def save_qa_history(file_hash: str, question: str, answer: str):
        """Append a question/answer pair (with a UTC timestamp) to the history.

        NOTE(review): this is a read-modify-write on a single key and is not
        atomic; concurrent writers for the same paper could overwrite each
        other's entry -- confirm whether that can happen in practice.
        """
        conn = await DatabaseOperations._open_redis(REDIS_CHAT_DB)
        try:
            history_key = f"chat_history:{file_hash}"
            stored = await conn.get(history_key)
            history = json.loads(stored) if stored else []
            history.append({
                "question": question,
                "answer": answer,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            await conn.set(history_key, json.dumps(history))
        finally:
            await conn.aclose()

    @staticmethod
    async def get_qa_history(file_hash: str):
        """Return the stored Q&A history, or an empty list if none exists."""
        return await DatabaseOperations._load_json(
            REDIS_CHAT_DB,
            f"chat_history:{file_hash}",
            default=[],
        )

    # ------------------------------------------------------------------
    # Redis: task status
    # ------------------------------------------------------------------

    @staticmethod
    async def save_task_status(task_id: str, status_data: dict):
        """Write ``status_data`` into the ``task:{task_id}`` hash.

        Fields whose value is ``None`` are skipped: redis-py cannot encode
        ``None`` as a hash value, and one missing optional field should not
        make the whole status update fail.
        """
        fields = {
            name: value
            for name, value in status_data.items()
            if value is not None
        }
        conn = await DatabaseOperations._open_redis(REDIS_TASK_DB)
        try:
            await conn.hset(f"task:{task_id}", mapping=fields)
        finally:
            await conn.aclose()

    @staticmethod
    async def get_task_status(task_id: str):
        """Return the ``task:{task_id}`` hash, or ``None`` if it is empty."""
        conn = await DatabaseOperations._open_redis(REDIS_TASK_DB)
        try:
            task_data = await conn.hgetall(f"task:{task_id}")
            return task_data or None
        finally:
            await conn.aclose()
|
||||
@@ -1,59 +1,23 @@
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Query
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
import asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
from redis import asyncio as aioredis
|
||||
from typing import Any
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import os
|
||||
import PyPDF2
|
||||
import aiohttp
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from openai import OpenAI
|
||||
import hashlib
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Database Configuration
|
||||
MONGODB_URL = "mongodb://paper:SYX7cdJNMRbiytra@222.186.10.253:27017/paper"
|
||||
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
|
||||
from database import (
|
||||
connect_to_mongo, close_mongo_connection,
|
||||
DatabaseOperations as db_ops
|
||||
)
|
||||
from api import call_deepseek, call_deepseek_sync
|
||||
|
||||
|
||||
# Database connection
|
||||
class Database:
|
||||
"""数据库连接管理类"""
|
||||
client: AsyncIOMotorClient = None
|
||||
|
||||
db = Database()
|
||||
|
||||
async def get_database() -> AsyncIOMotorClient:
|
||||
return db.client["paper"]
|
||||
|
||||
async def connect_to_mongo():
|
||||
try:
|
||||
db.client = AsyncIOMotorClient(MONGODB_URL)
|
||||
# 验证连接
|
||||
await db.client.admin.command('ping')
|
||||
print("Successfully connected to MongoDB")
|
||||
except Exception as e:
|
||||
print(f"Could not connect to MongoDB: {e}")
|
||||
raise
|
||||
|
||||
async def close_mongo_connection():
|
||||
db.client.close()
|
||||
|
||||
# Redis setup
|
||||
async def get_redis():
|
||||
redis = aioredis.from_url(
|
||||
REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
return redis
|
||||
# 在全局范围创建线程池
|
||||
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -76,7 +40,6 @@ app.add_middleware(
|
||||
expose_headers=["*"]
|
||||
)
|
||||
|
||||
|
||||
# 在其他 Pydantic 模型后添加
|
||||
class paperModel(BaseModel):
|
||||
"""文献引用模型"""
|
||||
@@ -89,101 +52,36 @@ class paperModel(BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
|
||||
# 删除文献
|
||||
@app.delete("/paper/delete/{paper_id}")
|
||||
async def delete_paper(
|
||||
paper_id: str
|
||||
):
|
||||
async def delete_paper(paper_id: str):
|
||||
"""删除文献"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 获取文献信息
|
||||
paper = await db.papers.find_one({"file_hash": paper_id})
|
||||
if not paper:
|
||||
result = await db_ops.delete_paper(paper_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="paper not found")
|
||||
|
||||
# 从文件存储服务删除文件
|
||||
file_hash = paper.get("file_hash")
|
||||
# 删除数据库记录
|
||||
await db.papers.delete_one({"file_hash": paper_id})
|
||||
|
||||
# 删除Redis中的分析报告(如果存在)
|
||||
try:
|
||||
redis = await get_redis()
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
await redis.delete(report_key)
|
||||
|
||||
# 删除聊天历史(如果存在)
|
||||
await redis.select(191)
|
||||
chat_history_key = f"chat_history:{paper_id}"
|
||||
await redis.delete(chat_history_key)
|
||||
except Exception as redis_error:
|
||||
print(f"Warning: Failed to delete Redis keys: {str(redis_error)}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
return {"message": "paper successfully deleted"}
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error deleting paper: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete paper: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/paper/papers")
|
||||
async def get_papers():
|
||||
"""获取文献列表"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
papers = []
|
||||
async for ref in db.papers.find():
|
||||
papers.append({
|
||||
"_id": str(ref["_id"]),
|
||||
"paper_link": ref["paper_link"],
|
||||
"paper_title": ref["paper_title"],
|
||||
"upload_time": ref["upload_time"],
|
||||
"file_hash": ref.get("file_hash")
|
||||
})
|
||||
return papers
|
||||
|
||||
return await db_ops.get_papers()
|
||||
except Exception as e:
|
||||
print(f"Error getting papers: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/paper/report/{file_hash}")
|
||||
async def get_report(
|
||||
file_hash: str
|
||||
):
|
||||
"""从 Redis db190 直接通过文件哈希值读取已保存的文献报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
async def get_report(file_hash: str):
|
||||
"""获取文献报告"""
|
||||
try:
|
||||
# 选择 db190
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
|
||||
# 获取已保存的报告
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
if not existing_report:
|
||||
report = await db_ops.get_paper_report(file_hash)
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="No saved paper report found")
|
||||
|
||||
# 返回报告
|
||||
return json.loads(existing_report)
|
||||
|
||||
return report
|
||||
except Exception as e:
|
||||
print(f"Error getting paper report: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
class FileUpload(BaseModel):
|
||||
filename: str
|
||||
@@ -195,9 +93,6 @@ class BatchUploadRequest(BaseModel):
|
||||
@app.post("/paper/upload")
|
||||
async def batch_upload(request: BatchUploadRequest):
|
||||
"""批量上传项目相关文献"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
uploaded_papers = []
|
||||
papers_to_analyze = []
|
||||
@@ -206,13 +101,12 @@ async def batch_upload(request: BatchUploadRequest):
|
||||
try:
|
||||
# 1. 创建新记录
|
||||
paper = {
|
||||
"paper_link": f"https://files.aiot.ml/pdf/{file.hash}",
|
||||
"paper_title": file.filename,
|
||||
"upload_time": datetime.now(timezone.utc),
|
||||
"file_hash": file.hash
|
||||
}
|
||||
|
||||
result = await db.papers.insert_one(paper)
|
||||
result = await db_ops.insert_paper(paper)
|
||||
paper_info = {
|
||||
"paper_id": str(result.inserted_id),
|
||||
"file_hash": file.hash,
|
||||
@@ -221,13 +115,7 @@ async def batch_upload(request: BatchUploadRequest):
|
||||
uploaded_papers.append(paper_info)
|
||||
|
||||
# 2. 设置分析状态
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file.hash}"
|
||||
initial_status = {
|
||||
"status": "processing",
|
||||
"message": "Analysis in progress"
|
||||
}
|
||||
await redis.set(report_key, json.dumps(initial_status))
|
||||
await db_ops.set_paper_processing_status(file.hash)
|
||||
papers_to_analyze.append(paper_info)
|
||||
|
||||
except Exception as file_error:
|
||||
@@ -254,56 +142,49 @@ async def batch_upload(request: BatchUploadRequest):
|
||||
except Exception as e:
|
||||
print(f"批量上传错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def get_pdf_content(file_hash: str) -> Optional[str]:
    """Fetch a PDF's extracted content from the file storage service.

    Requests ``https://files.aiot.ml/pdf/{file_hash}`` and reads the
    ``content`` field of the JSON response. A list value is joined with
    newlines into a single string.

    Args:
        file_hash: Hash identifying the stored PDF.

    Returns:
        The content, or ``None`` when the response status is not 200 or the
        ``content`` field is missing or empty.
    """
    pdf_url = f'https://files.aiot.ml/pdf/{file_hash}'
    async with aiohttp.ClientSession() as session:
        async with session.get(pdf_url) as response:
            if response.status != 200:
                print(f"获取PDF内容失败: {response.status}")
                return None

            payload = await response.json()
            text = payload.get('content')
            if not text:
                return None

            return '\n'.join(text) if isinstance(text, list) else text
|
||||
|
||||
async def batch_analysis(papers: List[dict]):
|
||||
"""批量处理文献分析的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
# 限制并发数量
|
||||
semaphore = asyncio.Semaphore(3)
|
||||
|
||||
async def single_paper(ref: dict):
|
||||
async with semaphore:
|
||||
try:
|
||||
paper_id = ref["paper_id"]
|
||||
file_hash = ref["file_hash"]
|
||||
|
||||
# 再次检查报告是否存在(以防在开始分析前已经被其他进程分析)
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
existing_report = await redis.get(report_key)
|
||||
if existing_report:
|
||||
report_data = json.loads(existing_report)
|
||||
if report_data.get("status") == "completed":
|
||||
print(f"Report already exists for file hash: {file_hash}, skipping analysis")
|
||||
return
|
||||
# 从文件存储服务获取PDF内容
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
|
||||
if response.status != 200:
|
||||
print(f"[Redis] 获取PDF内容失败: {response.status}")
|
||||
raise Exception(f"Failed to get PDF content for file hash: {file_hash}")
|
||||
|
||||
pdf_content = await response.json()
|
||||
if not pdf_content.get('content'):
|
||||
raise Exception("No PDF content returned")
|
||||
|
||||
# 检查报告是否存在
|
||||
if await db_ops.check_existing_report(file_hash):
|
||||
print(f"Report already exists for file hash: {file_hash}, skipping analysis")
|
||||
return
|
||||
|
||||
# 获取PDF内容
|
||||
content = await get_pdf_content(file_hash)
|
||||
if not content:
|
||||
await db_ops.save_analysis_error(file_hash, "Failed to get PDF content")
|
||||
return
|
||||
|
||||
print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
|
||||
print(f"文献ID: {paper_id}")
|
||||
print(f"文件哈希: {file_hash}")
|
||||
|
||||
# 使用获取到的PDF内容继续处理
|
||||
content = pdf_content['content']
|
||||
# 如果content是列表,将其合并为单个字符串
|
||||
if isinstance(content, list):
|
||||
content = '\n'.join(content)
|
||||
|
||||
# 打印字符数
|
||||
content_length = len(content)
|
||||
print(f"\n=== 步骤2: 内容长度检查 ===")
|
||||
@@ -315,19 +196,22 @@ async def batch_analysis(papers: List[dict]):
|
||||
print(f"文档长度在处理范围内 ({content_length} <= 200000)")
|
||||
document_analysis = await analyze_paper(content[:180000])
|
||||
if not document_analysis:
|
||||
raise Exception("Failed to analyze document")
|
||||
await db_ops.save_analysis_error(file_hash, "Failed to analyze document")
|
||||
return
|
||||
print("文档分析完成")
|
||||
else:
|
||||
print(f"\n=== 步骤3B: 使用分段分析方式 ===")
|
||||
print(f"文档超过200000字符 ({content_length} > 200000)")
|
||||
analysis_results = await analyze_long_file(content)
|
||||
if not analysis_results:
|
||||
raise Exception("Failed to analyze document in segments")
|
||||
await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments")
|
||||
return
|
||||
print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
|
||||
|
||||
document_analysis = await merge_results(analysis_results)
|
||||
if not document_analysis:
|
||||
raise Exception("Failed to merge analysis results")
|
||||
await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results")
|
||||
return
|
||||
|
||||
# 等待一小段时间避免API限制
|
||||
await asyncio.sleep(1)
|
||||
@@ -335,7 +219,8 @@ async def batch_analysis(papers: List[dict]):
|
||||
# 异步分析文献价值
|
||||
value_evaluation = await paper_value(document_analysis)
|
||||
if not value_evaluation:
|
||||
raise Exception("Failed to evaluate value")
|
||||
await db_ops.save_analysis_error(file_hash, "Failed to evaluate value")
|
||||
return
|
||||
print("文献价值分析完成")
|
||||
|
||||
# 合并结果
|
||||
@@ -347,98 +232,26 @@ async def batch_analysis(papers: List[dict]):
|
||||
}
|
||||
|
||||
# 保存结果
|
||||
await redis.select(190)
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
await db_ops.save_paper_report(file_hash, analysis_result)
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
error_status = {
|
||||
"status": "failed",
|
||||
"message": str(e)
|
||||
}
|
||||
await redis.set(report_key, json.dumps(error_status))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
print(f"分析文献时出错: {str(e)}")
|
||||
await db_ops.save_analysis_error(file_hash, str(e))
|
||||
|
||||
try:
|
||||
# 并发处理所有引用
|
||||
await asyncio.gather(
|
||||
*(single_paper(ref) for ref in papers)
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 在全局范围创建线程池
|
||||
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
api_key="sk-3027fb3c810b4e17985fa397d41250b9"
|
||||
)
|
||||
# 创建异步HTTP客户端会话
|
||||
async def get_aiohttp_session():
|
||||
return aiohttp.ClientSession(
|
||||
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
|
||||
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
|
||||
)
|
||||
|
||||
async def read_pdf(file_path: str) -> str:
|
||||
"""在线程池中异步读取PDF"""
|
||||
def read_pdf():
|
||||
try:
|
||||
pdf_content = ""
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
for page in pdf_reader.pages:
|
||||
content = page.extract_text()
|
||||
pdf_content += content
|
||||
|
||||
# 根据内容长度返回不同的结果
|
||||
content_length = len(pdf_content)
|
||||
print(f"\nPDF原始内容字符数: {content_length}")
|
||||
|
||||
if content_length <= 180000:
|
||||
print("文档长度 ≤ 180000,返回完整内容")
|
||||
return pdf_content
|
||||
elif content_length <= 200000:
|
||||
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
|
||||
return pdf_content[:180000]
|
||||
else:
|
||||
print("文档长度 > 200000,返回完整内容供分段处理")
|
||||
return pdf_content # 返回完整内容,由调用者处理分段
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading PDF: {e}")
|
||||
return ""
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
|
||||
|
||||
async def call_deepseek(messages: list) -> dict:
|
||||
"""异步调用DeepSeek API"""
|
||||
async with await get_aiohttp_session() as session:
|
||||
async with session.post("/chat/completions", json={
|
||||
"model": "deepseek-chat",
|
||||
"messages": messages,
|
||||
"response_format": {"type": "json_object"}
|
||||
}) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return json.loads(data["choices"][0]["message"]["content"])
|
||||
else:
|
||||
raise Exception(f"API调用失败: {await response.text()}")
|
||||
except Exception as e:
|
||||
print(f"批量分析任务出错: {str(e)}")
|
||||
|
||||
async def analyze_long_file(content: str) -> List[dict]:
|
||||
"""分段分析长文档"""
|
||||
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens)
|
||||
segments = []
|
||||
content_length = len(content)
|
||||
segment_size = 60000 # 减小段落大小
|
||||
segment_size = 180000 # 减小段落大小
|
||||
|
||||
print(f"\n开始分段处理,总字符数: {content_length}")
|
||||
print(f"每段大小: {segment_size} 字符")
|
||||
@@ -598,6 +411,7 @@ async def merge_results(results: List[dict]) -> dict:
|
||||
print("分析结果合并失败")
|
||||
|
||||
return result
|
||||
|
||||
async def analyze_paper(content: str):
|
||||
"""分析文献的基本信息和内容"""
|
||||
system_prompt = """
|
||||
@@ -698,20 +512,9 @@ class TaskStatus:
|
||||
FAILED = "failed"
|
||||
|
||||
@app.post("/paper/{file_hash}/qa")
|
||||
async def ask_reference_question(
|
||||
file_hash: str,
|
||||
question: QuestionModel
|
||||
):
|
||||
async def ask_reference_question(file_hash: str, question: QuestionModel):
|
||||
"""Ask questions about the paper (async)"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# Verify paper exists
|
||||
paper = await db.papers.find_one({"file_hash": file_hash})
|
||||
if not paper:
|
||||
raise HTTPException(status_code=404, detail="Paper not found")
|
||||
|
||||
# Generate task ID
|
||||
task_id = str(ObjectId())
|
||||
|
||||
@@ -719,8 +522,7 @@ async def ask_reference_question(
|
||||
asyncio.create_task(paper_question(
|
||||
task_id=task_id,
|
||||
file_hash=file_hash,
|
||||
question=question.question,
|
||||
paper=paper
|
||||
question=question.question
|
||||
))
|
||||
|
||||
return {
|
||||
@@ -730,56 +532,36 @@ async def ask_reference_question(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in ask_reference_question: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def paper_question(task_id: str, file_hash: str, question: str, paper: dict):
|
||||
async def paper_question(task_id: str, file_hash: str, question: str):
|
||||
"""Process paper Q&A background task"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# Get paper info
|
||||
paper = await db_ops.get_paper(file_hash)
|
||||
if not paper:
|
||||
raise Exception("Paper not found")
|
||||
|
||||
# Update task status to processing
|
||||
await redis.select(192) # Use db192 to store task status
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"file_hash": file_hash,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# Get analysis report from Redis
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
report_data = await redis.get(report_key)
|
||||
await db_ops.save_task_status(task_id, {
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"file_hash": file_hash,
|
||||
"question": question
|
||||
})
|
||||
|
||||
# Get analysis report
|
||||
report_data = await db_ops.get_paper_report(file_hash)
|
||||
if not report_data:
|
||||
raise Exception("Paper analysis report does not exist, please analyze first")
|
||||
|
||||
# Get PDF content from file storage service
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to get PDF content: HTTP {response.status}")
|
||||
|
||||
pdf_content = await response.json()
|
||||
if not pdf_content.get('content'):
|
||||
raise Exception("No PDF content returned")
|
||||
|
||||
content = pdf_content['content']
|
||||
# If content is a list, join it into a single string
|
||||
if isinstance(content, list):
|
||||
content = '\n'.join(content)
|
||||
|
||||
# Limit content length
|
||||
MAX_CHARS = 180000
|
||||
content = content[:MAX_CHARS]
|
||||
# Get PDF content
|
||||
content = await get_pdf_content(file_hash)
|
||||
if not content:
|
||||
raise Exception("Failed to get PDF content")
|
||||
|
||||
# Limit content length
|
||||
MAX_CHARS = 180000
|
||||
content = content[:MAX_CHARS]
|
||||
|
||||
# Build system prompt and user prompt
|
||||
system_prompt = """
|
||||
@@ -791,7 +573,7 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
|
||||
# Build context
|
||||
context = {
|
||||
"Paper Content": content,
|
||||
"Analysis Report": json.loads(report_data)
|
||||
"Analysis Report": report_data
|
||||
}
|
||||
|
||||
messages = [
|
||||
@@ -800,94 +582,39 @@ async def paper_question(task_id: str, file_hash: str, question: str, paper: dic
|
||||
]
|
||||
|
||||
# Call DeepSeek API for answer
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
answer = await call_deepseek_sync(messages)
|
||||
|
||||
# Save conversation history to Redis db191
|
||||
await redis.select(191)
|
||||
chat_history_key = f"chat_history:{file_hash}"
|
||||
|
||||
# Get existing history
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# Add new conversation
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# Save updated history
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
# Save conversation history
|
||||
await db_ops.save_qa_history(file_hash, question, answer)
|
||||
|
||||
# Update task status to completed
|
||||
await redis.select(192)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"paper_title": paper.get("paper_title")
|
||||
}
|
||||
)
|
||||
await db_ops.save_task_status(task_id, {
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"paper_title": paper.get("paper_title")
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# Update task status to failed
|
||||
await redis.select(192)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
await db_ops.save_task_status(task_id, {
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
@app.get("/paper/{file_hash}/qa/history")
|
||||
async def get_paper_qa_history(
|
||||
file_hash: str
|
||||
):
|
||||
async def get_paper_qa_history(file_hash: str):
|
||||
"""Get paper Q&A history"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# Get conversation history from Redis db191
|
||||
await redis.select(191)
|
||||
chat_history_key = f"chat_history:{file_hash}"
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
if not history_data:
|
||||
return []
|
||||
|
||||
return json.loads(history_data)
|
||||
|
||||
return await db_ops.get_qa_history(file_hash)
|
||||
except Exception as e:
|
||||
print(f"Error getting QA history: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@app.get("/paper/task/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
"""Get task status and result"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(192)
|
||||
task_data = await redis.hgetall(f"task:{task_id}")
|
||||
|
||||
task_data = await db_ops.get_task_status(task_id)
|
||||
if not task_data:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
@@ -896,59 +623,30 @@ async def get_task_status(task_id: str):
|
||||
"status": task_data.get("status", TaskStatus.PENDING)
|
||||
}
|
||||
|
||||
# If task completed, add results
|
||||
if task_data.get("status") == TaskStatus.COMPLETED:
|
||||
response.update({
|
||||
"answer": task_data.get("answer"),
|
||||
"paper_title": task_data.get("paper_title")
|
||||
})
|
||||
# If task failed, add error information
|
||||
elif task_data.get("status") == TaskStatus.FAILED:
|
||||
response.update({
|
||||
"error": task_data.get("error")
|
||||
})
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting task status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 添加新的路由,通过哈希值获取报告
|
||||
@app.get("/paper/check/{file_hash}")
|
||||
async def check_report(
|
||||
file_hash: str
|
||||
):
|
||||
"""直接通过文件哈希值从 Redis db190 读取已保存的文献报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
async def check_report(file_hash: str):
|
||||
"""检查文献报告状态"""
|
||||
try:
|
||||
# 选择 db190
|
||||
await redis.select(190)
|
||||
report_key = f"paper_report:{file_hash}"
|
||||
|
||||
# 获取已保存的报告
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
if not existing_report:
|
||||
report = await db_ops.get_paper_report(file_hash)
|
||||
if not report:
|
||||
return {"status": "not_found"}
|
||||
|
||||
# 返回报告
|
||||
return json.loads(existing_report)
|
||||
|
||||
return report
|
||||
except Exception as e:
|
||||
print(f"Error getting paper report by hash: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user