Files
2025-01-19 07:52:50 +00:00

654 lines
26 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
from datetime import datetime, timezone
import asyncio
from bson import ObjectId
from pydantic import BaseModel, Field
import json
import aiohttp
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from database import (
connect_to_mongo, close_mongo_connection,
DatabaseOperations as db_ops
)
from api import call_deepseek, call_deepseek_sync
# Module-wide thread pool intended for PDF processing. Its first positional
# argument (max_workers) caps concurrent PDF jobs at 3.
# NOTE(review): nothing in this module as shown submits work to this pool —
# confirm it is still needed.
pdf_thread_pool = ThreadPoolExecutor(3)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the MongoDB connection for the application's lifetime.

    Connects on startup and disconnects on shutdown. The disconnect sits in a
    ``finally`` so it also runs when an exception is thrown into the
    generator at the ``yield``. Without it, the code after ``yield`` would be
    skipped and the connection left open.

    Args:
        app: The FastAPI application (unused; required by the lifespan
            protocol).
    """
    await connect_to_mongo()
    try:
        yield
    finally:
        await close_mongo_connection()
# Application instance; `lifespan` opens and closes the MongoDB connection.
app = FastAPI(lifespan=lifespan)

# CORS: accept any origin, method and request header, allow credentials, and
# expose every response header to browsers.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    expose_headers=["*"],
)
# Added after the other Pydantic models.
# NOTE(review): this model is not referenced elsewhere in this module as shown.
class paperModel(BaseModel):
    """Paper reference model."""
    paper_link: str
    paper_title: str
    # Defaults to the current time as a timezone-aware UTC datetime.
    upload_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    class Config:
        # NOTE(review): `populate_by_name` is the Pydantic v2 option name, while a
        # class-based `Config` and `json_encoders` are v1-style — confirm which
        # Pydantic version is targeted.
        populate_by_name = True
        # Pydantic option allowing field types it does not natively support.
        arbitrary_types_allowed = True
        # Serialize ObjectId values as str in JSON output.
        json_encoders = {ObjectId: str}
# Delete a paper.
@app.delete("/paper/delete/{paper_id}")
async def delete_paper(paper_id: str):
    """Delete a paper by id.

    Args:
        paper_id: Identifier of the paper to delete, passed to
            ``db_ops.delete_paper``.

    Returns:
        The (truthy) value returned by ``db_ops.delete_paper``.

    Raises:
        HTTPException: 404 if ``db_ops.delete_paper`` returns a falsy value;
            500 with the error text if the database call raises.
    """
    try:
        result = await db_ops.delete_paper(paper_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # Kept outside the try so the generic handler cannot re-wrap this 404 as
    # a 500.
    if not result:
        raise HTTPException(status_code=404, detail="paper not found")
    return result
@app.get("/paper/papers")
async def get_papers():
    """Return the paper list produced by ``db_ops.get_papers()``.

    Raises:
        HTTPException: 500 with the error text if the database call fails.
    """
    try:
        papers = await db_ops.get_papers()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return papers
@app.get("/paper/report/{file_hash}")
async def get_report(file_hash: str):
    """Return the saved analysis report for a paper.

    Args:
        file_hash: Hash of the paper's file, used as the report key.

    Returns:
        The report returned by ``db_ops.get_paper_report``.

    Raises:
        HTTPException: 404 if no report is saved for ``file_hash``; 500 with
            the error text if the database call raises.
    """
    try:
        report = await db_ops.get_paper_report(file_hash)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # Kept outside the try so the generic handler cannot re-wrap this 404 as
    # a 500.
    if not report:
        raise HTTPException(status_code=404, detail="No saved paper report found")
    return report
class FileUpload(BaseModel):
    """One file entry in a batch upload request."""
    # Original file name; batch_upload stores it as the paper title.
    filename: str
    # File hash; used to look up the PDF content and the analysis report.
    hash: str
class BatchUploadRequest(BaseModel):
    """Request body for POST /paper/upload: the files to register."""
    files: List[FileUpload]
# Strong references to in-flight analysis tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_upload_analysis_tasks: set = set()


@app.post("/paper/upload")
async def batch_upload(request: BatchUploadRequest):
    """Register a batch of uploaded papers and start their analysis.

    For every file, a paper record is inserted and its processing status is
    set. A file for which either step fails is reported once, with
    ``status: "error"``, and is not analyzed. Successfully registered papers
    are analyzed by ``batch_analysis`` in a background task; this endpoint
    does not wait for that task.

    Args:
        request: The files (name + hash) to register.

    Returns:
        A dict with a summary ``message``, the per-file ``uploaded_files``
        entries and the number of ``files_to_analyze``.

    Raises:
        HTTPException: 500 if an unexpected error escapes the per-file
            handling.
    """
    try:
        uploaded_papers = []
        papers_to_analyze = []
        for upload in request.files:
            try:
                # 1. Create the paper record.
                paper = {
                    "paper_title": upload.filename,
                    "upload_time": datetime.now(timezone.utc),
                    "file_hash": upload.hash,
                }
                result = await db_ops.insert_paper(paper)
                paper_info = {
                    "paper_id": str(result.inserted_id),
                    "file_hash": upload.hash,
                    "paper_title": paper["paper_title"],
                }
                # 2. Mark the paper as being processed.
                await db_ops.set_paper_processing_status(upload.hash)
            except Exception as file_error:
                print(f"处理文件 {upload.filename} 时出错: {str(file_error)}")
                uploaded_papers.append({
                    "file_hash": upload.hash,
                    "paper_title": upload.filename,
                    "status": "error",
                    "error_message": str(file_error),
                })
            else:
                # Record success only after both steps succeeded. A failed
                # status update then yields a single error entry, not a
                # success entry plus an error entry.
                uploaded_papers.append(paper_info)
                papers_to_analyze.append(paper_info)

        # Start the analysis in the background and keep a reference to it.
        if papers_to_analyze:
            task = asyncio.create_task(batch_analysis(papers_to_analyze))
            _upload_analysis_tasks.add(task)
            task.add_done_callback(_upload_analysis_tasks.discard)
            print(f"[Redis] 开始分析新文件: {papers_to_analyze}")

        return {
            "message": f"Successfully processed {len(uploaded_papers)} files",
            "uploaded_files": uploaded_papers,
            "files_to_analyze": len(papers_to_analyze),
        }
    except Exception as e:
        print(f"批量上传错误: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
async def get_pdf_content(file_hash: str) -> Optional[str]:
    """Fetch a PDF's extracted text from the file-storage service.

    Requests ``https://files.aiot.ml/pdf/{file_hash}`` and reads the
    ``content`` field of the JSON reply. A list value is joined with
    newlines (presumably one item per page/chunk — TODO confirm with the
    storage service).

    Args:
        file_hash: Hash identifying the stored PDF.

    Returns:
        The text. Returns ``None`` in any of these cases:

        - the service answers with a non-200 status;
        - the JSON body is not an object;
        - ``content`` is missing or empty;
        - ``content`` is neither a string nor a list.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(f'https://files.aiot.ml/pdf/{file_hash}') as response:
            if response.status != 200:
                print(f"获取PDF内容失败: {response.status}")
                return None
            pdf_content = await response.json()

    # Valid JSON need not be an object (e.g. a list or null); calling .get on
    # such a value would raise AttributeError.
    if not isinstance(pdf_content, dict):
        return None
    content = pdf_content.get('content')
    if not content:
        return None
    if isinstance(content, list):
        return '\n'.join(content)
    if isinstance(content, str):
        return content
    # Any other type would break callers that treat the result as text.
    return None
async def batch_analysis(papers: List[dict]):
    """Background task that analyzes a batch of uploaded papers.

    Up to three papers are processed concurrently. Each paper dict must carry
    a ``"file_hash"``; entries without one are logged and skipped. For each
    paper:

    1. skip it if ``db_ops.check_existing_report`` reports an existing report;
    2. fetch the PDF text with ``get_pdf_content``;
    3. analyze the text in one request if it fits in ``direct_limit``
       characters; otherwise analyze it in segments with
       ``analyze_long_file`` and combine them with ``merge_results``;
    4. evaluate the paper's value with ``paper_value``;
    5. save analysis + evaluation with ``status: "completed"``.

    Failures are recorded with ``db_ops.save_analysis_error``. Errors that
    escape one paper's handler (e.g. the error write itself failing) are
    logged without stopping the other papers.

    Args:
        papers: Paper descriptors such as those built by ``batch_upload``
            (``paper_id``, ``file_hash``, ``paper_title``).
    """
    # Limit how many papers are analyzed concurrently.
    semaphore = asyncio.Semaphore(3)
    # Largest text (in characters) analyzed in a single request. This equals
    # the segment size used by analyze_long_file. A larger threshold here
    # would silently truncate documents whose length falls between the two
    # sizes.
    direct_limit = 180000

    async def single_paper(ref: dict):
        # Read the hash before the try: the error handler needs it.
        file_hash = ref.get("file_hash")
        if not file_hash:
            print(f"Skipping paper without file_hash: {ref!r}")
            return
        async with semaphore:
            try:
                # Skip papers that already have a report.
                if await db_ops.check_existing_report(file_hash):
                    print(f"Report already exists for file hash: {file_hash}, skipping analysis")
                    return
                # Fetch the PDF text.
                content = await get_pdf_content(file_hash)
                if not content:
                    await db_ops.save_analysis_error(file_hash, "Failed to get PDF content")
                    return
                print(f"\n开始处理文献 {ref.get('paper_title', '未知标题')}")
                print(f"文件哈希: {file_hash}")
                content_length = len(content)
                print(f"\n=== 步骤2: 内容长度检查 ===")
                print(f"PDF内容总字符数: {content_length}")
                # Choose the analysis strategy by document length.
                if content_length <= direct_limit:
                    print(f"\n=== 步骤3A: 使用直接分析方式 ===")
                    print(f"文档长度在处理范围内 ({content_length} <= {direct_limit})")
                    document_analysis = await analyze_paper(content)
                    if not document_analysis:
                        await db_ops.save_analysis_error(file_hash, "Failed to analyze document")
                        return
                    print("文档分析完成")
                else:
                    print(f"\n=== 步骤3B: 使用分段分析方式 ===")
                    print(f"文档超过{direct_limit}字符 ({content_length} > {direct_limit})")
                    analysis_results = await analyze_long_file(content)
                    if not analysis_results:
                        await db_ops.save_analysis_error(file_hash, "Failed to analyze document in segments")
                        return
                    print(f"分段分析完成,共分析了 {len(analysis_results)} 个段落")
                    document_analysis = await merge_results(analysis_results)
                    if not document_analysis:
                        await db_ops.save_analysis_error(file_hash, "Failed to merge analysis results")
                        return
                # Short pause to ease API rate limits.
                await asyncio.sleep(1)
                # Evaluate the paper's value from the content analysis.
                value_evaluation = await paper_value(document_analysis)
                if not value_evaluation:
                    await db_ops.save_analysis_error(file_hash, "Failed to evaluate value")
                    return
                print("文献价值分析完成")
                print("\n=== 步骤5: 保存最终结果 ===")
                analysis_result = {
                    **document_analysis,
                    **value_evaluation,
                    "status": "completed",
                }
                await db_ops.save_paper_report(file_hash, analysis_result)
            except Exception as e:
                print(f"分析文献时出错: {str(e)}")
                await db_ops.save_analysis_error(file_hash, str(e))

    # return_exceptions=True: one paper's escaped error must not hide the
    # outcome of the others; each is logged individually.
    outcomes = await asyncio.gather(
        *(single_paper(ref) for ref in papers),
        return_exceptions=True,
    )
    for outcome in outcomes:
        if isinstance(outcome, BaseException):
            print(f"批量分析任务出错: {str(outcome)}")
async def analyze_long_file(content: str, segment_size: int = 180000) -> List[dict]:
    """Analyze a long document segment by segment.

    The text is cut into consecutive, non-overlapping slices of at most
    ``segment_size`` characters. Each slice is sent to ``call_deepseek``
    with a prompt asking for JSON containing basic information, content
    analysis and an optional Mermaid methodology flowchart. Segments are
    processed one after another. After each successful call there is a
    1-second pause to ease API rate limits.

    Args:
        content: Full document text.
        segment_size: Maximum characters per segment (default 180000).

    Returns:
        The truthy results of ``call_deepseek``, in segment order. Segments
        whose call raises or returns a falsy value are skipped, so the list
        may be shorter than the number of segments, or empty.

    Raises:
        ValueError: If ``segment_size`` is not positive.
    """
    if segment_size <= 0:
        raise ValueError("segment_size must be a positive integer")
    content_length = len(content)
    print(f"\n开始分段处理,总字符数: {content_length}")
    print(f"每段大小: {segment_size} 字符")
    # Fixed-size character slices; the last one may be shorter.
    segments = [
        content[start:start + segment_size]
        for start in range(0, content_length, segment_size)
    ]
    print(f"文档已分段,共 {len(segments)} 个段落")

    analysis_results = []
    for i, segment in enumerate(segments):
        print(f"\n开始分析第 {i+1}/{len(segments)} 个段落...")
        system_prompt = f"""
You are an AI assistant tasked with analyzing part {i+1} of {len(segments)} of an academic paper.
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
Note that this is part {i+1} of a longer document, so focus on the content provided.
The JSON structure must strictly follow the provided template.
Also generate a Mermaid flowchart code to visualize the research methodology if this segment contains methodology information.
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
"""
        user_prompt = f"""Analyze the following paper segment and extract all relevant information:
Content: {segment}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Author name(s) and affiliations]",
"publication_date": "[Publication date in YYYY-MM format]",
"title": "[Full title of the document]",
"journal_publisher": "[Journal name or publisher details]",
"document_type": "[Type: journal article/book/conference paper etc.]"
}},
"2. Content Analysis": {{
"abstract": "[Paper abstract]",
"research_purpose": "[Main objectives and research questions]",
"methodology": "[Research methods, data collection and analysis approaches]",
"main_arguments": "[Key theoretical frameworks and arguments]",
"conclusions": "[Complete findings and conclusions]",
"innovations": "[Novel contributions and original aspects]"
}},
"flowchart": "[If this segment contains methodology information, generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node paper
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If this segment does not contain methodology information, set this field to null.]"
}}
"""
        try:
            result = await call_deepseek([
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ])
            if result:
                print(f"{i+1} 个段落分析完成")
                analysis_results.append(result)
                # Short pause to ease API rate limits.
                await asyncio.sleep(1)
            else:
                print(f"{i+1} 个段落分析失败")
        except Exception as e:
            # Best effort: a failed segment is logged and skipped.
            print(f"分析段落时出错: {str(e)}")
    return analysis_results
async def merge_results(results: List[dict]) -> dict:
    """Merge per-segment analyses of one paper into a single analysis.

    All partial results are serialized to JSON (non-ASCII kept as-is) and
    sent to ``call_deepseek`` with a merge prompt. The prompt asks for the
    same basic-information / content-analysis / flowchart structure.

    Args:
        results: Per-segment analysis dicts.

    Returns:
        ``{}`` when ``results`` is empty. Otherwise whatever
        ``call_deepseek`` returns, which may be falsy on failure.
    """
    if not results:
        return {}

    print(f"开始合并 {len(results)} 个分析结果...")
    instructions = """
You are an AI assistant tasked with merging multiple analysis results of different parts of the same academic paper.
Generate a comprehensive merged analysis in JSON format.
The JSON structure must strictly follow the provided template.
Ensure the merged result is coherent and eliminates redundancy.
For the flowchart, select the most comprehensive one from the input results, or combine multiple flowcharts if they contain complementary information.
When combining flowcharts, ensure the result follows Mermaid syntax rules and avoids special characters.
Create a detailed and comprehensive flowchart that accurately represents the paper's complete research methodology.
"""
    request_text = f"""Merge the following analysis results into a single coherent analysis:
Analysis Results: {json.dumps(results, ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Merged author information]",
"publication_date": "[Publication date]",
"title": "[Complete title]",
"journal_publisher": "[Journal/publisher information]",
"document_type": "[Document type]"
}},
"2. Content Analysis": {{
"abstract": "[Complete abstract]",
"research_purpose": "[Comprehensive research objectives]",
"methodology": "[Complete methodology description]",
"main_arguments": "[Comprehensive theoretical frameworks and arguments]",
"conclusions": "[Complete findings and conclusions]",
"innovations": "[Complete list of innovations]"
}},
"flowchart": "[Select or combine the flowcharts following these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node paper
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If no valid flowchart is found in any segment, set this field to null.]"
}}
"""
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request_text},
    ]
    merged = await call_deepseek(conversation)
    print("分析结果合并完成" if merged else "分析结果合并失败")
    return merged
async def analyze_paper(content: str):
    """Analyze a whole paper in a single model request.

    The paper text is sent to ``call_deepseek`` with a prompt asking for JSON
    containing basic information, content analysis and a Mermaid
    methodology flowchart (or null).

    Args:
        content: The paper text. The caller is responsible for keeping it
            within the model's input limit.

    Returns:
        Whatever ``call_deepseek`` returns.
    """
    instructions = """
You are an AI assistant tasked with analyzing academic paper.
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
The JSON structure must strictly follow the provided template.
Also generate a Mermaid flowchart code to visualize the research methodology.
Ensure the flowchart follows strict formatting rules to avoid parsing errors.
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
"""
    request_text = f"""Analyze the following paper and extract all relevant information:
Content: {content}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Author name(s) and affiliations]",
"publication_date": "[Publication date in YYYY-MM format]",
"title": "[Full title of the document]",
"journal_publisher": "[Journal name or publisher details]",
"document_type": "[Type: journal article/book/conference paper etc.]"
}},
"2. Content Analysis": {{
"abstract": "[Paper abstract]",
"research_purpose": "[Main objectives and research questions]",
"methodology": "[Research methods, data collection and analysis approaches]",
"main_arguments": "[Key theoretical frameworks and arguments]",
"conclusions": "[Primary findings and conclusions]",
"innovations": "[Novel contributions and original aspects]"
}},
"flowchart": "[Generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node paper
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If the paper does not contain clear methodology information, set this field to null.]"
}}
"""
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request_text},
    ]
    return await call_deepseek(conversation)
async def paper_value(content_analysis: dict):
    """Ask the model to evaluate a paper's value from its content analysis.

    Args:
        content_analysis: Analysis dict. It is serialized to JSON with
            non-ASCII characters kept as-is.

    Returns:
        Whatever ``call_deepseek`` returns. The prompt requests a
        ``"3. Value Evaluation"`` object.
    """
    instructions = """
You are an AI assistant tasked with evaluating the value of academic paper based on its content analysis.
Generate a comprehensive value evaluation in JSON format.
The JSON structure must strictly follow the provided template.
"""
    serialized_analysis = json.dumps(content_analysis, ensure_ascii=False)
    request_text = f"""Based on the following content analysis, evaluate the paper's value:
Content Analysis: {serialized_analysis}
Generate a JSON response with the following structure:
{{
"3. Value Evaluation": {{
"academic_contribution": "[Significance to the field of study]",
"practical_significance": "[Real-world applications and implications]",
"limitations": "[Research constraints and weaknesses]",
"implications": "[Suggestions for future research and practice]"
}}
}}
"""
    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request_text},
    ]
    return await call_deepseek(conversation)
# Models for the question-answering endpoints.
class QuestionModel(BaseModel):
    """Request body carrying a question about a paper."""
    question: str
class TaskStatus:
    """Plain-string constants for the lifecycle of a background Q&A task.

    These values are written into task status records and returned in API
    responses.
    """

    PENDING: str = "pending"
    PROCESSING: str = "processing"
    COMPLETED: str = "completed"
    FAILED: str = "failed"
# Strong references to running Q&A tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# mid-execution.
_qa_background_tasks: set = set()


@app.post("/paper/{file_hash}/qa")
async def ask_reference_question(file_hash: str, question: QuestionModel):
    """Ask questions about the paper (async).

    Records the task with status ``pending`` before scheduling the
    background job, so a client polling the task endpoint right after this
    call finds it instead of getting "Task not found". The answer is
    produced by ``paper_question`` in the background.

    Args:
        file_hash: Hash of the paper's file.
        question: Request body with the question text.

    Returns:
        ``task_id``, ``status`` (pending) and a confirmation ``message``.

    Raises:
        HTTPException: 500 if the task cannot be recorded or scheduled.
    """
    try:
        task_id = str(ObjectId())
        # Persist the task before the background job starts. The job later
        # overwrites this record via save_task_status with the same task_id.
        await db_ops.save_task_status(task_id, {
            "status": TaskStatus.PENDING,
            "file_hash": file_hash,
            "question": question.question,
        })
        task = asyncio.create_task(paper_question(
            task_id=task_id,
            file_hash=file_hash,
            question=question.question,
        ))
        _qa_background_tasks.add(task)
        task.add_done_callback(_qa_background_tasks.discard)
        return {
            "task_id": task_id,
            "status": TaskStatus.PENDING,
            "message": "Question submitted, processing in progress",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def paper_question(task_id: str, file_hash: str, question: str):
    """Background job that answers one question about a paper.

    Steps (any failure aborts the job):
      1. load the paper record with ``db_ops.get_paper``;
      2. mark the task ``processing``;
      3. load the saved analysis report;
      4. fetch the PDF text and keep at most its first 180000 characters;
      5. ask the model, passing the text and report as JSON context;
      6. save the Q&A pair to history and mark the task ``completed``.

    On any exception, the error is printed and the task is marked ``failed``
    with the exception message as ``error``.

    Args:
        task_id: Identifier under which the task status is saved.
        file_hash: Hash of the paper's file; used to look up the paper, its
            report and its PDF text.
        question: The user's question.
    """
    try:
        paper = await db_ops.get_paper(file_hash)
        if not paper:
            raise Exception("Paper not found")

        await db_ops.save_task_status(task_id, {
            "status": TaskStatus.PROCESSING,
            "file_hash": file_hash,
            "question": question,
        })

        report_data = await db_ops.get_paper_report(file_hash)
        if not report_data:
            raise Exception("Paper analysis report does not exist, please analyze first")

        full_text = await get_pdf_content(file_hash)
        if not full_text:
            raise Exception("Failed to get PDF content")
        # Cap how much PDF text goes into the prompt.
        char_budget = 180000
        excerpt = full_text[:char_budget]

        instructions = """
You are a professional academic assistant responsible for answering questions about academic papers.
You should provide accurate and professional answers based on the paper content and analysis report.
Answers should be concise and clear, citing specific content from the paper whenever possible.
"""
        serialized_context = json.dumps(
            {"Paper Content": excerpt, "Analysis Report": report_data},
            ensure_ascii=False,
        )
        user_message = (
            "Answer the question based on the following paper content and analysis report:"
            f"\n\nPaper Information:{serialized_context}"
            f"\n\nQuestion:{question}"
        )
        answer = await call_deepseek_sync([
            {"role": "system", "content": instructions},
            {"role": "user", "content": user_message},
        ])

        await db_ops.save_qa_history(file_hash, question, answer)
        await db_ops.save_task_status(task_id, {
            "status": TaskStatus.COMPLETED,
            "answer": answer,
            "paper_title": paper.get("paper_title"),
        })
    except Exception as e:
        print(f"Error processing question: {e}")
        await db_ops.save_task_status(task_id, {
            "status": TaskStatus.FAILED,
            "error": str(e),
        })
@app.get("/paper/{file_hash}/qa/history")
async def get_paper_qa_history(file_hash: str):
    """Return the stored Q&A history for a paper.

    Raises:
        HTTPException: 500 with the error text if the database call fails.
    """
    try:
        history = await db_ops.get_qa_history(file_hash)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return history
@app.get("/paper/task/{task_id}")
async def get_task_status(task_id: str):
    """Get task status and result.

    Args:
        task_id: Identifier returned by the question-submission endpoint.

    Returns:
        ``task_id`` and ``status`` (``pending`` when the record has no
        status). Also ``answer`` and ``paper_title`` when completed, or
        ``error`` when failed.

    Raises:
        HTTPException: 404 if no task record exists; 500 with the error text
            if the database call raises.
    """
    try:
        task_data = await db_ops.get_task_status(task_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    # Kept outside the try so the generic handler cannot re-wrap this 404 as
    # a 500.
    if not task_data:
        raise HTTPException(status_code=404, detail="Task not found")

    status = task_data.get("status", TaskStatus.PENDING)
    response = {"task_id": task_id, "status": status}
    if status == TaskStatus.COMPLETED:
        response.update({
            "answer": task_data.get("answer"),
            "paper_title": task_data.get("paper_title"),
        })
    elif status == TaskStatus.FAILED:
        response["error"] = task_data.get("error")
    return response
@app.get("/paper/check/{file_hash}")
async def check_report(file_hash: str):
    """Check whether an analysis report exists for a paper.

    Returns the saved report if there is one, otherwise
    ``{"status": "not_found"}``.

    Raises:
        HTTPException: 500 with the error text if the database call fails.
    """
    try:
        report = await db_ops.get_paper_report(file_hash)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return report if report else {"status": "not_found"}
if __name__ == "__main__":
    # Run the API directly with uvicorn, listening on all interfaces, port 9005.
    import uvicorn

    uvicorn.run(app, port=9005, host="0.0.0.0")