update
This commit is contained in:
@@ -0,0 +1,74 @@
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# MongoDB 配置
|
||||
MONGODB_URL = "mongodb://localhost:27017"
|
||||
DATABASE_NAME = "lab"
|
||||
|
||||
# Redis 配置
|
||||
REDIS_URL = "redis://localhost:6379"
|
||||
|
||||
# DeepSeek API 配置
|
||||
DEEPSEEK_API_CONFIG = {
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
"api_key": "sk-3027fb3c810b4e17985fa397d41250b9"
|
||||
}
|
||||
|
||||
# 文件上传配置
|
||||
UPLOAD_PATH = "/obscura/task/references"
|
||||
|
||||
# Redis 数据库映射
|
||||
REDIS_DB_MAP = {
|
||||
"experiment_analysis": 201, # 实验分析报告
|
||||
"project_analysis": 202, # 项目分析报告
|
||||
"reference_analysis": 203, # 文献分析报告
|
||||
"reference_summary": 204, # 文献汇总分析
|
||||
"project_memo": 205, # 项目备忘录
|
||||
"experiment_memo": 206, # 实验备忘录
|
||||
"chat_history": 207, # 聊天历史记录
|
||||
"task_status": 208 # 任务状态
|
||||
}
|
||||
|
||||
# 任务状态枚举
|
||||
class TaskStatus:
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
# 线程池配置
|
||||
THREAD_POOL_CONFIG = {
|
||||
"pdf_workers": 3,
|
||||
"analysis_workers": 3,
|
||||
"pro_analysis_workers": 3
|
||||
}
|
||||
|
||||
# API 响应格式
|
||||
def format_response(data: Any, message: str = "Success") -> Dict[str, Any]:
|
||||
"""格式化 API 响应"""
|
||||
return {
|
||||
"status": "success",
|
||||
"message": message,
|
||||
"data": data,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
def format_error(message: str, status_code: int = 500) -> Dict[str, Any]:
|
||||
"""格式化错误响应"""
|
||||
return {
|
||||
"status": "error",
|
||||
"message": message,
|
||||
"error_code": status_code,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# 文件类型配置
|
||||
ALLOWED_FILE_TYPES = {
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
|
||||
# 确保上传目录存在
|
||||
os.makedirs(UPLOAD_PATH, exist_ok=True)
|
||||
@@ -0,0 +1,53 @@
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from contextlib import asynccontextmanager
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
# Database Configuration
|
||||
MONGODB_URL = "mongodb://lab:y6aHwySAhzrbibLD@222.186.10.253:27017/lab"
|
||||
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
|
||||
|
||||
class PyObjectId(ObjectId):
|
||||
"""
|
||||
自定义ObjectId类,用于在Pydantic模型中处理MongoDB的ObjectId
|
||||
"""
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v, handler):
|
||||
if not ObjectId.is_valid(v):
|
||||
raise ValueError("Invalid ObjectId")
|
||||
return ObjectId(v)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, _schema_cache, **_kwargs):
|
||||
return {
|
||||
'type': 'string',
|
||||
'description': 'ObjectId',
|
||||
'pattern': r'^[0-9a-fA-F]{24}$'
|
||||
}
|
||||
|
||||
class Database:
|
||||
"""数据库连接管理类"""
|
||||
client: AsyncIOMotorClient = None
|
||||
|
||||
async def get_database() -> AsyncIOMotorClient:
|
||||
return Database.client.lab
|
||||
|
||||
async def connect_to_mongo():
|
||||
Database.client = AsyncIOMotorClient(MONGODB_URL)
|
||||
|
||||
async def close_mongo_connection():
|
||||
if Database.client:
|
||||
Database.client.close()
|
||||
|
||||
async def get_redis():
|
||||
redis = aioredis.from_url(
|
||||
REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
return redis
|
||||
@@ -0,0 +1,72 @@
|
||||
# Standard library imports
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Third-party imports
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from openai import OpenAI
|
||||
|
||||
# Local application imports
|
||||
from .cores.config import DEEPSEEK_API_CONFIG
|
||||
from .cores.db import (
|
||||
connect_to_mongo,
|
||||
close_mongo_connection,
|
||||
)
|
||||
from .routers.login import router as login_router
|
||||
from .routers.project import router as project_router
|
||||
from .routers.device import router as device_router
|
||||
from .routers.experiment_device import router as experiment_device_router
|
||||
from .routers.experiment import router as experiment_router
|
||||
from .routers.experiment_report import router as experiment_report_router
|
||||
from .routers.websocket import router as websocket_router
|
||||
from .routers.project_report import router as project_report_router
|
||||
from .routers.memo import router as memo_router
|
||||
from .routers.paper import router as paper_router
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
await connect_to_mongo()
|
||||
yield
|
||||
# Shutdown
|
||||
await close_mongo_connection()
|
||||
|
||||
# 更新 FastAPI 实例化
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# 添加路由
|
||||
app.include_router(login_router)
|
||||
app.include_router(project_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(experiment_device_router)
|
||||
app.include_router(experiment_router)
|
||||
app.include_router(experiment_report_router)
|
||||
app.include_router(websocket_router)
|
||||
app.include_router(project_report_router)
|
||||
app.include_router(memo_router)
|
||||
app.include_router(paper_router)
|
||||
|
||||
# CORS configuration
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"]
|
||||
)
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=6000)
|
||||
@@ -0,0 +1,614 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Dict
|
||||
import PyPDF2
|
||||
import aiohttp
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from openai import OpenAI
|
||||
from ..cores.config import DEEPSEEK_API_CONFIG
|
||||
from ..cores.db import get_redis
|
||||
from enum import Enum
|
||||
|
||||
# 创建OpenAI客户端
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
# 修改函数定义为异步函数
|
||||
async def analyze_reference_summary(reference_data):
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing paper reports.
|
||||
You will summarize and analyze all paper reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the following paper reports:
|
||||
Paper reports: {json.dumps(reference_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Paper Summary Report": {{
|
||||
"overview": {{
|
||||
"total_papers": "[Number of papers]",
|
||||
"time_range": {{"start_year": "[Start year]", "end_year": "[End year]"}},
|
||||
"main_research_areas": "[Main research areas of the papers]"
|
||||
}},
|
||||
"research_trends": {{
|
||||
"major_themes": "[Major themes areas of the papers]",
|
||||
"common_methodologies": "[Common methodologies used in the papers]",
|
||||
"emerging_topics": "[Emerging topics in the papers]"
|
||||
}},
|
||||
"key_findings": {{
|
||||
"theoretical_advances": "[Theoretical advances in the papers]",
|
||||
"experimental_results": "[Experimental results in the papers]",
|
||||
"common_conclusions": "[Common conclusions in the papers]"
|
||||
}},
|
||||
"research_gaps": {{
|
||||
"current_limitations": "[Current limitations]",
|
||||
"unexplored_areas": "[Unexplored areas]",
|
||||
"technical_challenges": "[Technical challenges"
|
||||
}},
|
||||
"future_directions": {{
|
||||
"potential_applications": "[Potential future applications]",
|
||||
"methodological_suggestions": "[Methodological suggestions based on paper summary]"
|
||||
}},
|
||||
"impact_assessment": {{
|
||||
"academic_influence": "[Summarize the academic influence of the paper]",
|
||||
"practical_value": "[Summarize the practical value of the paper]"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def run_analysis_in_thread(project_id: str, reference_ids: List[str]):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_reference_analysis(project_id, reference_ids))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_reference_analysis(project_id: str, reference_ids: List[str]):
|
||||
"""后台处理文献分析任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新状态键
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
|
||||
# 从Redis db203获取所有实验报告
|
||||
await redis.select(203)
|
||||
all_reference_reports = {}
|
||||
completed_count = 0
|
||||
|
||||
for ref_id in reference_ids:
|
||||
report_key = f"reference_report:{ref_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_reference_reports[ref_id] = parsed_report
|
||||
completed_count += 1
|
||||
|
||||
# 更新进度
|
||||
await redis.select(204)
|
||||
current_status = await redis.get(status_key)
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
status_data["completed_references"] = completed_count
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
await redis.select(203)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error parsing reference {ref_id} report")
|
||||
continue
|
||||
|
||||
if all_reference_reports:
|
||||
analysis_result = await analyze_reference_summary(all_reference_reports)
|
||||
|
||||
if analysis_result:
|
||||
await redis.select(204)
|
||||
report_key = f"reference_summary_report:{project_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新最终状态
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat(),
|
||||
"total_references": len(reference_ids),
|
||||
"completed_references": completed_count
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
print(f"Reference analysis completed for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print(f"Failed to generate analysis for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "No valid reference reports found",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print(f"No valid reference reports found for project {project_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in background reference analysis: {str(e)}")
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def process_reference_question(task_id: str, reference_id: str, question: str, reference: dict):
|
||||
"""处理文献问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"reference_id": reference_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{reference_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("文献分析报告不存在,请先进行分析")
|
||||
|
||||
# 读取PDF文件内容
|
||||
file_path = reference.get("reference_link")
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
raise Exception("文献文件不存在")
|
||||
|
||||
# 读取PDF内容
|
||||
pdf_content = ""
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
for page in pdf_reader.pages:
|
||||
content = page.extract_text()
|
||||
pdf_content += content
|
||||
|
||||
# 限制内容长度
|
||||
MAX_CHARS = 180000
|
||||
pdf_content = pdf_content[:MAX_CHARS]
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的学术助手,负责回答关于学术文献的问题。
|
||||
你应该基于文献内容和分析报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用文献中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"文献内容": pdf_content,
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下文献内容和分析报告回答问题:\n\n文献信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"chat_history:{reference_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"reference_title": reference.get("reference_title")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
|
||||
# 在全局范围创建线程池
|
||||
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
|
||||
|
||||
# 创建异步HTTP客户端会话
|
||||
async def get_aiohttp_session():
|
||||
return aiohttp.ClientSession(
|
||||
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
|
||||
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
|
||||
)
|
||||
|
||||
async def read_pdf_async(file_path: str) -> str:
|
||||
"""在线程池中异步读取PDF"""
|
||||
def read_pdf():
|
||||
try:
|
||||
pdf_content = ""
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
for page in pdf_reader.pages:
|
||||
content = page.extract_text()
|
||||
pdf_content += content
|
||||
|
||||
# 根据内容长度返回不同的结果
|
||||
content_length = len(pdf_content)
|
||||
print(f"\nPDF原始内容字符数: {content_length}")
|
||||
|
||||
if content_length <= 180000:
|
||||
print("文档长度 ≤ 180000,返回完整内容")
|
||||
return pdf_content
|
||||
elif content_length <= 200000:
|
||||
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
|
||||
return pdf_content[:180000]
|
||||
else:
|
||||
print("文档长度 > 200000,返回完整内容供分段处理")
|
||||
return pdf_content # 返回完整内容,由调用者处理分段
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading PDF: {e}")
|
||||
return ""
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
|
||||
|
||||
async def call_deepseek_api_async(messages: list) -> dict:
|
||||
"""异步调用DeepSeek API"""
|
||||
async with await get_aiohttp_session() as session:
|
||||
async with session.post("/chat/completions", json={
|
||||
"model": "deepseek-chat",
|
||||
"messages": messages,
|
||||
"response_format": {"type": "json_object"}
|
||||
}) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return json.loads(data["choices"][0]["message"]["content"])
|
||||
else:
|
||||
raise Exception(f"API调用失败: {await response.text()}")
|
||||
|
||||
async def analyze_long_document_async(content: str) -> List[dict]:
|
||||
"""分段分析长文档"""
|
||||
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens)
|
||||
segments = []
|
||||
content_length = len(content)
|
||||
segment_size = 60000 # 减小段落大小
|
||||
|
||||
print(f"\n开始分段处理,总字符数: {content_length}")
|
||||
print(f"每段大小: {segment_size} 字符")
|
||||
|
||||
for i in range(0, content_length, segment_size):
|
||||
segment = content[i:i + segment_size]
|
||||
segments.append(segment)
|
||||
|
||||
print(f"文档已分段,共 {len(segments)} 个段落")
|
||||
|
||||
# 对每个段落进行分析
|
||||
analysis_results = []
|
||||
for i, segment in enumerate(segments):
|
||||
print(f"\n开始分析第 {i+1}/{len(segments)} 个段落...")
|
||||
|
||||
system_prompt = f"""
|
||||
You are an AI assistant tasked with analyzing part {i+1} of {len(segments)} of an academic paper.
|
||||
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
|
||||
Note that this is part {i+1} of a longer document, so focus on the content provided.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
Also generate a Mermaid flowchart code to visualize the research methodology if this segment contains methodology information.
|
||||
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the following paper segment and extract all relevant information:
|
||||
Content: {segment}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"1. Basic Information": {{
|
||||
"author": "[Author name(s) and affiliations]",
|
||||
"publication_date": "[Publication date in YYYY-MM format]",
|
||||
"title": "[Full title of the document]",
|
||||
"journal_publisher": "[Journal name or publisher details]",
|
||||
"document_type": "[Type: journal article/book/conference paper etc.]"
|
||||
}},
|
||||
"2. Content Analysis": {{
|
||||
"abstract": "[Paper abstract]",
|
||||
"research_purpose": "[Main objectives and research questions]",
|
||||
"methodology": "[Research methods, data collection and analysis approaches]",
|
||||
"main_arguments": "[Key theoretical frameworks and arguments]",
|
||||
"conclusions": "[Complete findings and conclusions]",
|
||||
"innovations": "[Novel contributions and original aspects]"
|
||||
}},
|
||||
"flowchart": "[If this segment contains methodology information, generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
|
||||
1. Use 'graph TD' for top-down flow
|
||||
2. Each node should be in format: id[text] where:
|
||||
- id is a unique identifier (A, B1, B2, etc.)
|
||||
- text should be simple and clear, using only letters, numbers, and spaces
|
||||
- DO NOT use any special characters including parentheses, colons, commas
|
||||
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
|
||||
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
|
||||
3. Connections use '-->' between nodes
|
||||
4. Ensure each line ends with a proper node reference
|
||||
|
||||
Create a detailed flowchart that shows:
|
||||
- Research objectives and questions
|
||||
- All major research methods used
|
||||
- Data collection and analysis processes
|
||||
- Key experimental or analytical steps
|
||||
- Result synthesis and conclusion formation
|
||||
|
||||
Make the flowchart as detailed as possible while maintaining clarity.
|
||||
If this segment does not contain methodology information, set this field to null.]"
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
result = await call_deepseek_api_async([
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
])
|
||||
|
||||
if result:
|
||||
print(f"第 {i+1} 个段落分析完成")
|
||||
analysis_results.append(result)
|
||||
# 等待一小段时间避免API限制
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
print(f"第 {i+1} 个段落分析失败")
|
||||
except Exception as e:
|
||||
print(f"分析段落时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
return analysis_results
|
||||
|
||||
async def merge_analysis_results(results: List[dict]) -> dict:
|
||||
"""合并多个分析结果"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
print(f"开始合并 {len(results)} 个分析结果...")
|
||||
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with merging multiple analysis results of different parts of the same academic paper.
|
||||
Generate a comprehensive merged analysis in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
Ensure the merged result is coherent and eliminates redundancy.
|
||||
For the flowchart, select the most comprehensive one from the input results, or combine multiple flowcharts if they contain complementary information.
|
||||
When combining flowcharts, ensure the result follows Mermaid syntax rules and avoids special characters.
|
||||
Create a detailed and comprehensive flowchart that accurately represents the paper's complete research methodology.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Merge the following analysis results into a single coherent analysis:
|
||||
Analysis Results: {json.dumps(results, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"1. Basic Information": {{
|
||||
"author": "[Merged author information]",
|
||||
"publication_date": "[Publication date]",
|
||||
"title": "[Complete title]",
|
||||
"journal_publisher": "[Journal/publisher information]",
|
||||
"document_type": "[Document type]"
|
||||
}},
|
||||
"2. Content Analysis": {{
|
||||
"abstract": "[Complete abstract]",
|
||||
"research_purpose": "[Comprehensive research objectives]",
|
||||
"methodology": "[Complete methodology description]",
|
||||
"main_arguments": "[Comprehensive theoretical frameworks and arguments]",
|
||||
"conclusions": "[Complete findings and conclusions]",
|
||||
"innovations": "[Complete list of innovations]"
|
||||
}},
|
||||
"flowchart": "[Select or combine the flowcharts following these rules:
|
||||
1. Use 'graph TD' for top-down flow
|
||||
2. Each node should be in format: id[text] where:
|
||||
- id is a unique identifier (A, B1, B2, etc.)
|
||||
- text should be simple and clear, using only letters, numbers, and spaces
|
||||
- DO NOT use any special characters including parentheses, colons, commas
|
||||
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
|
||||
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
|
||||
3. Connections use '-->' between nodes
|
||||
4. Ensure each line ends with a proper node reference
|
||||
|
||||
Create a detailed flowchart that shows:
|
||||
- Research objectives and questions
|
||||
- All major research methods used
|
||||
- Data collection and analysis processes
|
||||
- Key experimental or analytical steps
|
||||
- Result synthesis and conclusion formation
|
||||
|
||||
Make the flowchart as detailed as possible while maintaining clarity.
|
||||
If no valid flowchart is found in any segment, set this field to null.]"
|
||||
}}
|
||||
"""
|
||||
|
||||
result = await call_deepseek_api_async([
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
])
|
||||
|
||||
if result:
|
||||
print("分析结果合并完成")
|
||||
else:
|
||||
print("分析结果合并失败")
|
||||
|
||||
return result
|
||||
async def analyze_reference_document_async(content: str):
|
||||
"""分析文献的基本信息和内容"""
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with analyzing academic paper.
|
||||
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
Also generate a Mermaid flowchart code to visualize the research methodology.
|
||||
Ensure the flowchart follows strict formatting rules to avoid parsing errors.
|
||||
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the following paper and extract all relevant information:
|
||||
Content: {content}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"1. Basic Information": {{
|
||||
"author": "[Author name(s) and affiliations]",
|
||||
"publication_date": "[Publication date in YYYY-MM format]",
|
||||
"title": "[Full title of the document]",
|
||||
"journal_publisher": "[Journal name or publisher details]",
|
||||
"document_type": "[Type: journal article/book/conference paper etc.]"
|
||||
}},
|
||||
"2. Content Analysis": {{
|
||||
"abstract": "[Paper abstract]",
|
||||
"research_purpose": "[Main objectives and research questions]",
|
||||
"methodology": "[Research methods, data collection and analysis approaches]",
|
||||
"main_arguments": "[Key theoretical frameworks and arguments]",
|
||||
"conclusions": "[Primary findings and conclusions]",
|
||||
"innovations": "[Novel contributions and original aspects]"
|
||||
}},
|
||||
"flowchart": "[Generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
|
||||
1. Use 'graph TD' for top-down flow
|
||||
2. Each node should be in format: id[text] where:
|
||||
- id is a unique identifier (A, B1, B2, etc.)
|
||||
- text should be simple and clear, using only letters, numbers, and spaces
|
||||
- DO NOT use any special characters including parentheses, colons, commas
|
||||
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
|
||||
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
|
||||
3. Connections use '-->' between nodes
|
||||
4. Ensure each line ends with a proper node reference
|
||||
|
||||
Create a detailed flowchart that shows:
|
||||
- Research objectives and questions
|
||||
- All major research methods used
|
||||
- Data collection and analysis processes
|
||||
- Key experimental or analytical steps
|
||||
- Result synthesis and conclusion formation
|
||||
|
||||
Make the flowchart as detailed as possible while maintaining clarity.
|
||||
If the paper does not contain clear methodology information, set this field to null.]"
|
||||
}}
|
||||
"""
|
||||
|
||||
return await call_deepseek_api_async([
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
])
|
||||
|
||||
async def analyze_reference_value_async(content_analysis: dict):
|
||||
"""基于内容分析结果评估文献的价值"""
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with evaluating the value of academic paper based on its content analysis.
|
||||
Generate a comprehensive value evaluation in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Based on the following content analysis, evaluate the paper's value:
|
||||
Content Analysis: {json.dumps(content_analysis, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"3. Value Evaluation": {{
|
||||
"academic_contribution": "[Significance to the field of study]",
|
||||
"practical_significance": "[Real-world applications and implications]",
|
||||
"limitations": "[Research constraints and weaknesses]",
|
||||
"implications": "[Suggestions for future research and practice]"
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
return await call_deepseek_api_async([
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
])
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/lab/user/devices/{serial_number}")
|
||||
async def add_device_to_user(
|
||||
serial_number: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
通过序列号添加设备到用户设备列表
|
||||
|
||||
参数:
|
||||
serial_number: 设备序列号
|
||||
current_user: 当前登录用户
|
||||
返回:
|
||||
设备信息和添加状态
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 查找具有该序列号的设备
|
||||
device = await db.devices.find_one({
|
||||
"serial_numbers": {
|
||||
"$elemMatch": {
|
||||
"serial": serial_number
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Device not found or invalid serial number")
|
||||
# 检查该序列号是否已被其他用户激活
|
||||
active_device = await db.user_devices.find_one({
|
||||
"serial_number": serial_number,
|
||||
"status": "active"
|
||||
})
|
||||
if active_device:
|
||||
raise HTTPException(status_code=400, detail="Device is already in use by another user")
|
||||
|
||||
# Check if this serial number has already been added
|
||||
existing = await db.user_devices.find_one({
|
||||
"user_id": current_user.id,
|
||||
"serial_number": serial_number
|
||||
})
|
||||
if existing:
|
||||
if existing["status"] == "active":
|
||||
raise HTTPException(status_code=400, detail="This device is already active")
|
||||
else:
|
||||
# 如果设备存在但状态是inactive,则重新激活
|
||||
await db.user_devices.update_one(
|
||||
{"_id": existing["_id"]},
|
||||
{"$set": {"status": "active"}}
|
||||
)
|
||||
return {
|
||||
"message": "Device reactivated successfully",
|
||||
"device": {
|
||||
"user_device_id": str(existing["_id"]),
|
||||
"device_id": str(existing["device_id"]),
|
||||
"serial_number": serial_number,
|
||||
"device_name": existing["device_name"],
|
||||
"device_type": existing["device_type"],
|
||||
"device_number": existing["device_number"],
|
||||
"sensors": existing.get("sensors", []),
|
||||
"status": "active"
|
||||
}
|
||||
}
|
||||
|
||||
# 创建新的关联记录,包含完整的设备信息、序列号和传感器信息
|
||||
user_device_dict = {
|
||||
"user_id": current_user.id,
|
||||
"device_id": device["_id"],
|
||||
"serial_number": serial_number,
|
||||
"device_name": device["device_name"],
|
||||
"device_type": device["device_type"],
|
||||
"device_number": device["device_number"],
|
||||
"sensors": device.get("sensors", []), # 添加传感器信息
|
||||
"status": "active", # 添加状态字段
|
||||
"add_time": datetime.now(timezone.utc) # 添加时间记录
|
||||
}
|
||||
# 只有当序列号状态为available时才更新为in_use
|
||||
await db.devices.update_one(
|
||||
{
|
||||
"_id": device["_id"],
|
||||
"serial_numbers": {
|
||||
"$elemMatch": {
|
||||
"serial": serial_number,
|
||||
"status": "available" # 只匹配状态为available的序列号
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"serial_numbers.$.status": "in_use"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
result = await db.user_devices.insert_one(user_device_dict)
|
||||
|
||||
# 返回设备信息
|
||||
return {
|
||||
"message": "Device added successfully",
|
||||
"device": {
|
||||
"user_device_id": str(result.inserted_id),
|
||||
"device_id": str(device["_id"]),
|
||||
"serial_number": serial_number,
|
||||
"device_name": device["device_name"],
|
||||
"device_type": device["device_type"],
|
||||
"device_number": device["device_number"],
|
||||
"sensors": device.get("sensors", []), # 在响应中包含传感器信息
|
||||
"status": "active" # 添加状态字段"status": "active" # 添加状态字段
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error adding device to user: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add device: {str(e)}")
|
||||
|
||||
@router.get("/lab/userdevices")
|
||||
async def get_user_devices(current_user: UserModel = Depends(get_current_user)):
|
||||
"""
|
||||
获取当前用户的所有传感器列表
|
||||
|
||||
返回:
|
||||
List[Dict]: 传感器列表,包含当前用户关联的所有传感器信息
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
devices = []
|
||||
async for device in db.user_devices.find({"user_id": current_user.id}):
|
||||
devices.append({
|
||||
"user_device_id": str(device["_id"]),
|
||||
"device_id": str(device["device_id"]),
|
||||
"serial_number": device["serial_number"], # 添加序列号
|
||||
"device_name": device["device_name"],
|
||||
"device_type": device["device_type"],
|
||||
"device_number": device["device_number"],
|
||||
"sensors": device.get("sensors", []), # 在响应中包含传感器信息
|
||||
"status": device.get("status", "active")
|
||||
})
|
||||
|
||||
return devices
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting user devices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get user devices: {str(e)}")
|
||||
|
||||
@router.delete("/lab/user/devices/{device_id}")
|
||||
async def remove_user_device(
|
||||
device_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
将设备标记为非活动状态
|
||||
|
||||
参数:
|
||||
device_id: 用户设备的ID
|
||||
current_user: 当前登录用户
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证设备是否存在且属于当前用户,且状态为active
|
||||
device = await db.user_devices.find_one({
|
||||
"_id": ObjectId(device_id),
|
||||
"user_id": current_user.id,
|
||||
"status": "active" # 只能停用活动状态的设备
|
||||
})
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Active device not found or unauthorized access")
|
||||
|
||||
# 更新设备状态为非活动
|
||||
result = await db.user_devices.update_one(
|
||||
{
|
||||
"_id": ObjectId(device_id),
|
||||
"user_id": current_user.id
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"status": "inactive",
|
||||
"deactivate_time": datetime.now(timezone.utc) # 记录停用时间
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
raise HTTPException(status_code=500, detail="Failed to deactivate device")
|
||||
|
||||
return {
|
||||
"message": "Device deactivated successfully",
|
||||
"device_id": device_id,
|
||||
"status": "inactive", # 添加状态字段
|
||||
"deactivate_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deactivating device: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to deactivate device: {str(e)}")
|
||||
@@ -0,0 +1,632 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
from ..cores.db import PyObjectId, get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
from io import BytesIO, StringIO
|
||||
import zipfile
|
||||
import csv
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class ExperimentStatus:
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
|
||||
class ExperimentModel(BaseModel):
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
project_id: PyObjectId
|
||||
experiment_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
status: str = Field(default=ExperimentStatus.ACTIVE) # 添加状态字段
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentCreate(BaseModel):
|
||||
"""实验创建请求模型"""
|
||||
project_id: str
|
||||
experiment_name: str
|
||||
description: str | None = None # 修改为可选字段,默认为None
|
||||
|
||||
# 创建实验
|
||||
@router.post("/lab/experiments")
|
||||
async def create_experiment(
|
||||
experiment: ExperimentCreate,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""创建新实验"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证项目是否存在且于当前用户
|
||||
project = await db.projects.find_one({
|
||||
"_id": ObjectId(experiment.project_id),
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
|
||||
|
||||
# 构造实验文档
|
||||
experiment_dict = {
|
||||
"project_id": ObjectId(experiment.project_id),
|
||||
"experiment_name": experiment.experiment_name,
|
||||
"description": experiment.description or "", # 如果None则使用空字符串
|
||||
"create_time": datetime.now(timezone.utc),
|
||||
"status": ExperimentStatus.ACTIVE
|
||||
}
|
||||
|
||||
result = await db.experiments.insert_one(experiment_dict)
|
||||
|
||||
return {
|
||||
"message": "Experiment created successfully",
|
||||
"id": str(result.inserted_id)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create experiment: {str(e)}")
|
||||
|
||||
# 获取实验列表
|
||||
@router.get("/lab/experiments")
|
||||
async def get_experiments(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取项目下的所有实验"""
|
||||
db = await get_database()
|
||||
|
||||
# 验证项目是否存在且属于当前用户
|
||||
project = await db.projects.find_one({
|
||||
"_id": ObjectId(project_id),
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
|
||||
|
||||
# 查询实验列表
|
||||
experiments = []
|
||||
async for experiment in db.experiments.find({"project_id": ObjectId(project_id)}):
|
||||
experiments.append(ExperimentModel(**experiment))
|
||||
|
||||
return experiments
|
||||
# 数据模型
|
||||
class ExperimentSession(BaseModel):
|
||||
"""实会话模型,记录每次实验的开始和结束时间"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_id: PyObjectId
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
duration: Optional[float] = None # 持续时间(秒)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentData(BaseModel):
|
||||
"""实验数据模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
session_ids: List[PyObjectId] # 修改为session_ids数组
|
||||
user_id: PyObjectId
|
||||
device_id: str
|
||||
sensor_name: str
|
||||
last_update: datetime # 添加最后更新时间
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# 开始实验
|
||||
@router.post("/lab/experiments/{experiment_id}/start")
|
||||
async def start_experiment_session(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""开始新的实验会话"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 首先获取该实验关联的所有设备
|
||||
experiment_devices = []
|
||||
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
|
||||
experiment_devices.append({
|
||||
"user_device_id": str(device["user_device_id"]),
|
||||
"device_name": device["device_name"],
|
||||
"serial_number": device["serial_number"],
|
||||
"sensors": device.get("sensors", [])
|
||||
})
|
||||
|
||||
if not experiment_devices:
|
||||
raise HTTPException(status_code=400, detail="No devices added to the experiment")
|
||||
|
||||
# 更新实验状态到Redis (新增)
|
||||
await redis.select(199)
|
||||
for device in experiment_devices:
|
||||
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
|
||||
await redis.set(status_key, "active")
|
||||
|
||||
# 创建实验会话记录
|
||||
session = {
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"user_id": current_user.id,
|
||||
"start_time": datetime.now(timezone.utc),
|
||||
"devices": experiment_devices
|
||||
}
|
||||
|
||||
result = await db.experiment_sessions.insert_one(session)
|
||||
session_id = str(result.inserted_id)
|
||||
|
||||
return {
|
||||
"message": "Experiment session started",
|
||||
"session_id": session_id,
|
||||
"devices": experiment_devices
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting experiment session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start experiment session: {str(e)}")
|
||||
finally:
|
||||
if redis:
|
||||
await redis.aclose()
|
||||
|
||||
# 停止实验
|
||||
@router.post("/lab/experiments/{experiment_id}/stop")
|
||||
async def stop_experiment_session(
|
||||
experiment_id: str,
|
||||
session_data: dict,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""停止实验会话"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
session_id = session_data.get("session_id")
|
||||
if not session_id:
|
||||
raise HTTPException(status_code=422, detail="Missing session_id parameter")
|
||||
|
||||
# 验证session是否存在且未结束
|
||||
session = await db.experiment_sessions.find_one({
|
||||
"_id": ObjectId(session_id),
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"end_time": None
|
||||
})
|
||||
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Active experiment session not found")
|
||||
|
||||
# 更新实验状态到Redis (新增)
|
||||
await redis.select(199) # 使用同一个db
|
||||
for device in session.get("devices", []):
|
||||
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
|
||||
await redis.set(status_key, "inactive")
|
||||
|
||||
# 确保start_time是带时区的
|
||||
start_time = session["start_time"]
|
||||
if start_time.tzinfo is None:
|
||||
start_time = start_time.replace(tzinfo=timezone.utc)
|
||||
|
||||
end_time = datetime.now(timezone.utc) # 确保end_time带有UTC时区
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
# 更新会话状态
|
||||
result = await db.experiment_sessions.update_one(
|
||||
{"_id": ObjectId(session_id)},
|
||||
{
|
||||
"$set": {
|
||||
"end_time": end_time,
|
||||
"duration": duration
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Failed to update session status")
|
||||
|
||||
return {
|
||||
"message": "Experiment session stopped",
|
||||
"session_id": session_id,
|
||||
"duration": duration
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error stopping experiment session: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to stop experiment session: {str(e)}")
|
||||
finally:
|
||||
if redis:
|
||||
await redis.aclose()
|
||||
|
||||
# 获取实验会话历史的路由
|
||||
@router.get("/lab/experiments/{experiment_id}/sessions")
|
||||
async def get_experiment_sessions(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验的所有会话记录"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
sessions = []
|
||||
async for session in db.experiment_sessions.find({
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
}).sort("start_time", -1):
|
||||
sessions.append({
|
||||
"session_id": str(session["_id"]),
|
||||
"start_time": session["start_time"],
|
||||
"end_time": session.get("end_time"),
|
||||
"duration": session.get("duration")
|
||||
})
|
||||
|
||||
return sessions
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment session records: {str(e)}")
|
||||
|
||||
# 导出实验数据
|
||||
@router.get("/lab/experiments/{experiment_id}/export")
|
||||
async def export_experiment_data(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""导出实验的所有数据为ZIP文件(包含多个CSV)"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
zip_buffer = BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||
# 获取实验的所有会话
|
||||
sessions = []
|
||||
async for session in db.experiment_sessions.find({
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
}):
|
||||
end_time = session.get("end_time") or datetime.now(timezone.utc)
|
||||
duration = (end_time - session["start_time"]).total_seconds()
|
||||
sessions.append({
|
||||
"id": str(session["_id"]),
|
||||
"start_time": session["start_time"],
|
||||
"end_time": end_time,
|
||||
"devices": session.get("devices", []),
|
||||
"duration": duration
|
||||
})
|
||||
|
||||
if not sessions:
|
||||
raise HTTPException(status_code=404, detail="No experiment session data found")
|
||||
|
||||
# 写入会话信息到sessions.csv
|
||||
sessions_output = StringIO()
|
||||
sessions_writer = csv.writer(sessions_output)
|
||||
sessions_writer.writerow(['Session ID', 'Start Time', 'End Time', 'Duration (seconds)'])
|
||||
|
||||
for session in sessions:
|
||||
sessions_writer.writerow([
|
||||
session["id"],
|
||||
session["start_time"].isoformat(),
|
||||
session["end_time"].isoformat() if isinstance(session["end_time"], datetime) else "Active",
|
||||
session["duration"]
|
||||
])
|
||||
|
||||
zip_file.writestr('sessions.csv', sessions_output.getvalue())
|
||||
|
||||
# 创建README.txt
|
||||
readme_content = """Experiment Data Export Instructions:
|
||||
1. sessions.csv: Contains basic information for all experiment sessions
|
||||
2. Data for each device is stored in folders named by device serial number
|
||||
3. Each device folder contains multiple sensor_data_XXX.csv files, each file containing sensor data for a specific time period
|
||||
4. All timestamps are in UTC timezone"""
|
||||
|
||||
zip_file.writestr('README.txt', readme_content)
|
||||
|
||||
# 选择db200获取原始数据
|
||||
await redis.select(200)
|
||||
|
||||
# 按设备处理数据
|
||||
for session in sessions:
|
||||
for device in session["devices"]:
|
||||
serial_number = device["serial_number"]
|
||||
stream_key = f"experiment:{experiment_id}:{serial_number}"
|
||||
|
||||
try:
|
||||
start_ms = int(session["start_time"].timestamp() * 1000)
|
||||
end_ms = int(session["end_time"].timestamp() * 1000)
|
||||
|
||||
stream_data = await redis.xrange(
|
||||
stream_key,
|
||||
min=str(start_ms),
|
||||
max=str(end_ms)
|
||||
)
|
||||
|
||||
# 将数据分批处理,每个文件最多包含100000行数据
|
||||
MAX_ROWS_PER_FILE = 100000
|
||||
current_file_rows = 0
|
||||
file_counter = 1
|
||||
current_output = StringIO()
|
||||
current_writer = csv.writer(current_output)
|
||||
current_writer.writerow(['Session ID', 'Device Time', 'Sensor', 'Value'])
|
||||
|
||||
for entry_id, data in stream_data:
|
||||
stream_timestamp = int(entry_id.split('-')[0])
|
||||
device_time = datetime.fromtimestamp(
|
||||
stream_timestamp / 1000,
|
||||
tz=timezone.utc
|
||||
)
|
||||
|
||||
data_str = data[b'data'] if isinstance(data.get('data'), bytes) else data['data']
|
||||
if isinstance(data_str, bytes):
|
||||
data_str = data_str.decode('utf-8')
|
||||
|
||||
point_data = json.loads(data_str)
|
||||
channel_data = point_data["channel_data"]
|
||||
|
||||
for sensor_name, values in channel_data.items():
|
||||
if isinstance(values, list):
|
||||
for value in values:
|
||||
current_writer.writerow([
|
||||
session["id"],
|
||||
device_time.isoformat(),
|
||||
sensor_name,
|
||||
value
|
||||
])
|
||||
current_file_rows += 1
|
||||
else:
|
||||
current_writer.writerow([
|
||||
session["id"],
|
||||
device_time.isoformat(),
|
||||
sensor_name,
|
||||
values
|
||||
])
|
||||
current_file_rows += 1
|
||||
|
||||
# 如果当前文件达到行数限制,保存并创建新文件
|
||||
if current_file_rows >= MAX_ROWS_PER_FILE:
|
||||
file_name = f"{serial_number}/session_{session['id']}/sensor_data_{file_counter:03d}.csv"
|
||||
zip_file.writestr(file_name, current_output.getvalue())
|
||||
|
||||
current_output = StringIO()
|
||||
current_writer = csv.writer(current_output)
|
||||
current_writer.writerow(['Session ID', 'Device Time', 'Sensor', 'Value'])
|
||||
current_file_rows = 0
|
||||
file_counter += 1
|
||||
|
||||
# 保存最后一个文件
|
||||
if current_file_rows > 0:
|
||||
file_name = f"{serial_number}/session_{session['id']}/sensor_data_{file_counter:03d}.csv"
|
||||
zip_file.writestr(file_name, current_output.getvalue())
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing stream data: {str(e)}")
|
||||
continue
|
||||
|
||||
# 获取实验信息用于文件命名
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
zip_filename = f"experiment_{experiment['experiment_name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
|
||||
|
||||
# 准备ZIP文<50><E69687>下载
|
||||
zip_buffer.seek(0)
|
||||
|
||||
return StreamingResponse(
|
||||
iter([zip_buffer.getvalue()]),
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
'Content-Disposition': f'attachment; filename="{zip_filename}"'
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting data: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to export data: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 完成实验的路由
|
||||
@router.post("/lab/experiments/{experiment_id}/complete")
|
||||
async def complete_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""完成实验,将实验状态标记为已完成"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 检查是否有正在进行的会话
|
||||
active_session = await db.experiment_sessions.find_one({
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"end_time": None
|
||||
})
|
||||
|
||||
if active_session:
|
||||
raise HTTPException(status_code=400, detail="Please stop all ongoing experiment sessions first")
|
||||
|
||||
# 更新实验状态为已完成
|
||||
result = await db.experiments.update_one(
|
||||
{"_id": ObjectId(experiment_id)},
|
||||
{"$set": {"status": ExperimentStatus.COMPLETED}}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Failed to update experiment status")
|
||||
|
||||
return {"message": "Experiment completed"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error completing experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to complete experiment: {str(e)}")
|
||||
|
||||
|
||||
# 进入实验的路由,添加状态检查
|
||||
@router.get("/lab/experiments/{experiment_id}")
|
||||
async def get_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验详情,包括状态"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 创建一个新的字典来存储处理后的数据
|
||||
experiment_dict = {
|
||||
"_id": str(experiment["_id"]),
|
||||
"project_id": str(experiment["project_id"]),
|
||||
"experiment_name": experiment["experiment_name"],
|
||||
"create_time": experiment["create_time"],
|
||||
"description": experiment["description"],
|
||||
"status": experiment.get("status", ExperimentStatus.ACTIVE)
|
||||
}
|
||||
|
||||
# 获取实验会话历史
|
||||
sessions = []
|
||||
async for session in db.experiment_sessions.find({"experiment_id": ObjectId(experiment_id)}):
|
||||
session_dict = {
|
||||
"id": str(session["_id"]),
|
||||
"experiment_id": str(session["experiment_id"]),
|
||||
"start_time": session["start_time"],
|
||||
"end_time": session.get("end_time"),
|
||||
"duration": session.get("duration")
|
||||
}
|
||||
sessions.append(session_dict)
|
||||
|
||||
experiment_dict["sessions"] = sessions
|
||||
|
||||
return experiment_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment: {str(e)}")
|
||||
|
||||
# 删除实验的路由
|
||||
@router.delete("/lab/experiments/{experiment_id}")
|
||||
async def delete_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""删除实验及其相关数据"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在且属于当前用户的项目
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
project = await db.projects.find_one({
|
||||
"_id": experiment["project_id"],
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized to delete this experiment")
|
||||
|
||||
# 删除实验相关数据
|
||||
await db.experiment_sessions.delete_many({"experiment_id": ObjectId(experiment_id)})
|
||||
await db.experiment_devices.delete_many({"experiment_id": ObjectId(experiment_id)})
|
||||
await db.experiments.delete_one({"_id": ObjectId(experiment_id)})
|
||||
|
||||
# 删除Redis中的实验报告
|
||||
await redis.select(201)
|
||||
await redis.delete(f"experiment_report:{experiment_id}")
|
||||
|
||||
return {"message": "Experiment successfully deleted"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete experiment: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 添加公式模型
|
||||
class FormulaModel(BaseModel):
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
# 添加公式请求模型
|
||||
class FormulaCreate(BaseModel):
|
||||
sensor_name: str
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
@router.post("/lab/experiments/{experiment_id}/devices/{device_id}/formulas")
|
||||
async def add_sensor_formula(
|
||||
experiment_id: str,
|
||||
device_id: str,
|
||||
formula: FormulaCreate,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""为实验设备的传感器添加计算公式"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验设备是否存在,使用 user_device_id 查询
|
||||
device = await db.experiment_devices.find_one({
|
||||
"user_device_id": ObjectId(device_id),
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
})
|
||||
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="Experiment device not found")
|
||||
|
||||
# 查找对应的传感器
|
||||
sensor_found = False
|
||||
sensors = device.get("sensors", [])
|
||||
for sensor in sensors:
|
||||
if sensor["sensor_name"] == formula.sensor_name:
|
||||
# 初始化或获取现有公式列表
|
||||
if "formulas" not in sensor:
|
||||
sensor["formulas"] = []
|
||||
# 添加新公式
|
||||
sensor["formulas"].append({
|
||||
"data_name": formula.data_name,
|
||||
"data_unit": formula.data_unit,
|
||||
"formula": formula.formula
|
||||
})
|
||||
sensor_found = True
|
||||
break
|
||||
|
||||
if not sensor_found:
|
||||
raise HTTPException(status_code=404, detail="Sensor not found")
|
||||
|
||||
# 更新设备文档,使用 user_device_id 更新
|
||||
result = await db.experiment_devices.update_one(
|
||||
{
|
||||
"user_device_id": ObjectId(device_id),
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
},
|
||||
{"$set": {"sensors": sensors}}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
raise HTTPException(status_code=400, detail="Failed to add formula")
|
||||
|
||||
return {"message": "Formula added successfully"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error adding formula: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add formula: {str(e)}")
|
||||
@@ -0,0 +1,238 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class ExperimentDeviceModel(BaseModel):
|
||||
"""实验-设备关联模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_device_id: PyObjectId
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
@router.post("/lab/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
async def add_device_to_experiment(
|
||||
experiment_id: str,
|
||||
user_device_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""将设备添加到特定实验中"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在且属于当前用户
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 验证项目所有权
|
||||
project = await db.projects.find_one({
|
||||
"_id": experiment["project_id"],
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized access to this experiment")
|
||||
|
||||
# 验证用户设备是否存在且状态为active
|
||||
user_device = await db.user_devices.find_one({
|
||||
"_id": ObjectId(user_device_id),
|
||||
"user_id": current_user.id,
|
||||
"status": "active" # 添加状态检查
|
||||
})
|
||||
if not user_device:
|
||||
raise HTTPException(status_code=404, detail="Active device not found or unauthorized access")
|
||||
|
||||
# 检查是否已经添加过这个设备
|
||||
existing = await db.experiment_devices.find_one({
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"user_device_id": ObjectId(user_device_id)
|
||||
})
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="This device has already been added to the experiment")
|
||||
|
||||
# 创建关联记录,直接复制用户设备的所有信息
|
||||
experiment_device_dict = {
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"user_device_id": ObjectId(user_device_id),
|
||||
**{k:v for k,v in user_device.items() if k not in ['_id', 'user_id']}
|
||||
}
|
||||
|
||||
result = await db.experiment_devices.insert_one(experiment_device_dict)
|
||||
|
||||
return {
|
||||
"message": "Device added successfully",
|
||||
"experiment_device_id": str(result.inserted_id)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error adding device to experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add device: {str(e)}")
|
||||
|
||||
@router.get("/lab/experiments/{experiment_id}/devices")
|
||||
async def get_experiment_devices(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验中的所有设备信息"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在且属于当前用户
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 验证项目所有权
|
||||
project = await db.projects.find_one({
|
||||
"_id": experiment["project_id"],
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized access to this experiment")
|
||||
|
||||
# 获取实验关联的所有设备
|
||||
devices = []
|
||||
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
|
||||
# 确保返回传感器的公式信息
|
||||
device_info = {
|
||||
"_id": str(device["_id"]),
|
||||
"experiment_id": str(device["experiment_id"]),
|
||||
"user_device_id": str(device["user_device_id"]),
|
||||
"device_id": str(device["device_id"]),
|
||||
"serial_number": device["serial_number"],
|
||||
"device_name": device["device_name"],
|
||||
"device_type": device["device_type"],
|
||||
"device_number": device["device_number"],
|
||||
"sensors": []
|
||||
}
|
||||
|
||||
# 保留完整的传感器信息,包括 index
|
||||
for sensor in device.get("sensors", []):
|
||||
sensor_info = {
|
||||
"index": sensor["index"], # 保留 index
|
||||
"sensor_name": sensor["sensor_name"],
|
||||
"sensor_type": sensor["sensor_type"],
|
||||
"unit": sensor["unit"]
|
||||
}
|
||||
# 如果有公式信息,也添加进去
|
||||
if "formulas" in sensor:
|
||||
sensor_info["formulas"] = sensor["formulas"]
|
||||
|
||||
device_info["sensors"].append(sensor_info)
|
||||
|
||||
devices.append(device_info)
|
||||
|
||||
return devices
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting experiment devices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
|
||||
|
||||
@router.get("/lab/experiments/{experiment_id}/devices/public")
|
||||
async def get_experiment_devices_public(
|
||||
experiment_id: str
|
||||
):
|
||||
"""获取实验中的所有设备信息(公开接口)"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 获取实验关联的所有设备
|
||||
devices = []
|
||||
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
|
||||
# 确保返回传感器的公式信息
|
||||
device_info = {
|
||||
"_id": str(device["_id"]),
|
||||
"experiment_id": str(device["experiment_id"]),
|
||||
"user_device_id": str(device["user_device_id"]),
|
||||
"device_id": str(device["device_id"]),
|
||||
"serial_number": device["serial_number"],
|
||||
"device_name": device["device_name"],
|
||||
"device_type": device["device_type"],
|
||||
"device_number": device["device_number"],
|
||||
"sensors": []
|
||||
}
|
||||
|
||||
# 保留完整的传感器信息,包括 index
|
||||
for sensor in device.get("sensors", []):
|
||||
sensor_info = {
|
||||
"index": sensor["index"], # 保留 index
|
||||
"sensor_name": sensor["sensor_name"],
|
||||
"sensor_type": sensor["sensor_type"],
|
||||
"unit": sensor["unit"]
|
||||
}
|
||||
# 如果有公式信息,也添加进去
|
||||
if "formulas" in sensor:
|
||||
sensor_info["formulas"] = sensor["formulas"]
|
||||
|
||||
device_info["sensors"].append(sensor_info)
|
||||
|
||||
devices.append(device_info)
|
||||
|
||||
return devices
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting experiment devices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
|
||||
|
||||
@router.delete("/lab/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
async def remove_device_from_experiment(
|
||||
experiment_id: str,
|
||||
user_device_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
从实验中移除特定设备
|
||||
|
||||
参数:
|
||||
experiment_id: 实验ID
|
||||
user_device_id: 用户设备ID
|
||||
current_user: 当前登录用户
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在且属于当前用户
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
|
||||
# 验证项目所有权
|
||||
project = await db.projects.find_one({
|
||||
"_id": experiment["project_id"],
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized to delete this experiment")
|
||||
|
||||
# 删除实验-设备关联
|
||||
result = await db.experiment_devices.delete_one({
|
||||
"experiment_id": ObjectId(experiment_id),
|
||||
"user_device_id": ObjectId(user_device_id)
|
||||
})
|
||||
|
||||
if result.deleted_count == 0:
|
||||
raise HTTPException(status_code=404, detail="Device association not found")
|
||||
|
||||
return {
|
||||
"message": "Device removed from experiment",
|
||||
"experiment_id": experiment_id,
|
||||
"user_device_id": user_device_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error removing device from experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to remove device: {str(e)}")
|
||||
@@ -0,0 +1,515 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
from redis import asyncio as aioredis
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from ..cores.config import (
|
||||
MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG,
|
||||
REDIS_DB_MAP, TaskStatus, THREAD_POOL_CONFIG, format_response,
|
||||
exp_analysis_thread_pool
|
||||
)
|
||||
|
||||
from ..cores.db import get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["experiment_report"]
|
||||
)
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
async def analyze_experiment_data(experiment_info):
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with analyzing experimental data.
|
||||
Generate a comprehensive experiment analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the experiment data based on the following information:
|
||||
Experiment data: {json.dumps(experiment_info['sessions'], ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Experiment Analysis Report": {{
|
||||
"1. Basic Information": {{
|
||||
"Total Sessions": "{experiment_info['total_sessions']}",
|
||||
"Total Duration": "{experiment_info['total_duration']} seconds",
|
||||
"Data Points": "{experiment_info['total_points']}",
|
||||
"Devices number": "{experiment_info['device_stats']['total_devices']}",
|
||||
"Sensors number": "{experiment_info['device_stats']['total_sensors']}"
|
||||
}},
|
||||
"2. Session Data Analysis": {{
|
||||
"[Session ID]": {{
|
||||
"Duration": "[Duration] seconds",
|
||||
"Data Points": "[data_points]"
|
||||
}},
|
||||
// Repeat for each session dynamically
|
||||
}},
|
||||
"3. Key Findings": [
|
||||
"[Finding 1]",
|
||||
"[Finding 2]",
|
||||
"[Finding 3]"
|
||||
],
|
||||
"4. Recommendations": [
|
||||
"[Recommendation 1]",
|
||||
"[Recommendation 2]",
|
||||
"[Recommendation 3]"
|
||||
]
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
def run_experiment_in_thread(experiment_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_experiment_analysis(experiment_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
# 清理所有待处理的任务
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
# 运行直到所有任务完成
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up tasks: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_experiment_analysis(experiment_id: str):
|
||||
"""后台处理实验分析任务"""
|
||||
print(f"\n=== 开始实验分析 ===")
|
||||
print(f"实验ID: {experiment_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取实验基本信息
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise ValueError("实验不存在")
|
||||
|
||||
# 收集会话数据
|
||||
sessions = []
|
||||
total_points = 0
|
||||
total_duration = 0
|
||||
devices_set = set()
|
||||
sensors_count = 0
|
||||
|
||||
async for session in db.experiment_sessions.find({
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
}):
|
||||
# 计算会话持续时间
|
||||
end_time = session.get("end_time") or datetime.now(timezone.utc)
|
||||
duration = (end_time - session["start_time"]).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
# 转换时间戳为毫秒
|
||||
start_ms = int(session["start_time"].timestamp() * 1000)
|
||||
end_ms = int(end_time.timestamp() * 1000)
|
||||
|
||||
# 统计设备和传感器数量
|
||||
session_devices = session.get("devices", [])
|
||||
for device in session_devices:
|
||||
devices_set.add(device["serial_number"])
|
||||
sensors_count += len(device.get("sensors", []))
|
||||
|
||||
# 获取会话数据点数
|
||||
session_points = 0
|
||||
for device in session_devices:
|
||||
stream_key = f"experiment:{experiment_id}:{device['serial_number']}"
|
||||
await redis.select(200)
|
||||
data_points = await redis.xrange(
|
||||
stream_key,
|
||||
min=str(start_ms),
|
||||
max=str(end_ms)
|
||||
)
|
||||
session_points += len(data_points)
|
||||
|
||||
total_points += session_points
|
||||
sessions.append({
|
||||
"session_id": str(session["_id"]),
|
||||
"duration": duration,
|
||||
"data_points": session_points
|
||||
})
|
||||
|
||||
if not sessions:
|
||||
raise ValueError("没有找到实验会话数据")
|
||||
|
||||
experiment_info = {
|
||||
"experiment_name": experiment["experiment_name"],
|
||||
"total_sessions": len(sessions),
|
||||
"total_duration": total_duration,
|
||||
"total_points": total_points,
|
||||
"device_stats": {
|
||||
"total_devices": len(devices_set),
|
||||
"total_sensors": sensors_count
|
||||
},
|
||||
"sessions": sessions
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_experiment_data(experiment_info)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态为完成
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("分析失败")
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in experiment analysis: {str(e)}")
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing connections: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/analyze")
|
||||
async def analyze_data(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""启动实验数据分析"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
current_status = await redis.get(status_key)
|
||||
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
if status_data.get("status") == "processing":
|
||||
return format_response({
|
||||
"experiment_id": experiment_id,
|
||||
"start_time": status_data.get("start_time")
|
||||
}, "实验分析任务正在进行中")
|
||||
|
||||
# 记录分析开始状态
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 使用线程池提交任务
|
||||
exp_analysis_thread_pool.submit(run_experiment_in_thread, experiment_id)
|
||||
|
||||
return format_response({
|
||||
"experiment_id": experiment_id,
|
||||
"start_time": status_data["start_time"]
|
||||
}, "实验分析任务已启动")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting experiment analysis: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/analysis_status")
|
||||
async def get_experiment_analysis_status(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验分析任务的状态"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
return format_response({
|
||||
"experiment_id": experiment_id,
|
||||
"status": "not_started"
|
||||
})
|
||||
|
||||
return format_response(json.loads(status_data))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting analysis status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/report")
|
||||
async def get_saved_report(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis 读取已保存的实验报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved experiment report found")
|
||||
|
||||
return format_response(json.loads(existing_report))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting experiment report: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
"""问题模型"""
|
||||
question: str
|
||||
|
||||
@router.post("/experiments/{experiment_id}/qa")
|
||||
async def ask_experiment_question(
|
||||
experiment_id: str,
|
||||
question: QuestionModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""向实验报告提问(异步)"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 验证实验是否存在
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise HTTPException(status_code=404, detail="实验不存在")
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(ObjectId())
|
||||
|
||||
# 创建后台任务
|
||||
asyncio.create_task(process_experiment_question(
|
||||
task_id=task_id,
|
||||
experiment_id=experiment_id,
|
||||
question=question.question,
|
||||
experiment=experiment
|
||||
))
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.PENDING,
|
||||
"message": "问题已提交,正在处理中"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in ask_experiment_question: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def process_experiment_question(task_id: str, experiment_id: str, question: str, experiment: dict):
|
||||
"""处理实验问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"experiment_id": experiment_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(201)
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("实验分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的实验助手,负责回答关于实验报告的问题。
|
||||
你应该基于实验报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"实验名称": experiment.get("experiment_name", "未知实验"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下实验报告回答问题:\n\n实验信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"experiment_chat_history:{experiment_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"experiment_name": experiment.get("experiment_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/qa/history")
|
||||
async def get_experiment_qa_history(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验问答历史记录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 从Redis db207获取对话历史
|
||||
await redis.select(207)
|
||||
chat_history_key = f"experiment_chat_history:{experiment_id}"
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
if not history_data:
|
||||
return []
|
||||
|
||||
return json.loads(history_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting QA history: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -0,0 +1,112 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Optional
|
||||
from bson import ObjectId
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
|
||||
# JWT Configuration
|
||||
SECRET_KEY = "Obscura@2024"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 30
|
||||
|
||||
# OAuth2 scheme
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
# Router
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["authentication"]
|
||||
)
|
||||
|
||||
# Pydantic models
|
||||
class UserModel(BaseModel):
|
||||
"""
|
||||
用户模型
|
||||
包含用户基本信息:用户名、密码、邮箱、姓名、所属机构
|
||||
"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
username: str
|
||||
password: str
|
||||
email: EmailStr
|
||||
name: str
|
||||
institution: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# Security functions
|
||||
def create_access_token(data: dict):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except (ExpiredSignatureError, InvalidSignatureError) as e:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
db = await get_database()
|
||||
user = await db.users.find_one({"username": username})
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return UserModel(**user)
|
||||
|
||||
# Auth routes
|
||||
@router.post("/register")
|
||||
async def register(user_data: UserModel):
|
||||
"""
|
||||
用户注册接口
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
# 检查用户名是否已存在
|
||||
if await db.users.find_one({"username": user_data.username}):
|
||||
raise HTTPException(status_code=400, detail="Username has already been registered")
|
||||
|
||||
# 创建用户文档
|
||||
user_dict = user_data.model_dump(exclude={"id"})
|
||||
|
||||
try:
|
||||
result = await db.users.insert_one(user_dict)
|
||||
return {
|
||||
"message": "Registration successful",
|
||||
"id": str(result.inserted_id)
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Database Error: {str(e)}")
|
||||
|
||||
@router.post("/token")
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
"""
|
||||
用户登录接口
|
||||
"""
|
||||
db = await get_database()
|
||||
user = await db.users.find_one({"username": form_data.username})
|
||||
|
||||
# 直接比较明文密码
|
||||
if not user or form_data.password != user["password"]:
|
||||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
||||
|
||||
access_token = create_access_token(data={"sub": user["username"]})
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
@@ -0,0 +1,141 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
from ..cores.db import get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["memo"]
|
||||
)
|
||||
|
||||
class MemoModel(BaseModel):
|
||||
"""项目备忘录模型"""
|
||||
content: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@router.post("/projects/{project_id}/memo")
|
||||
async def save_project_memo(
|
||||
project_id: str,
|
||||
memo: MemoModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""保存项目备忘录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db205
|
||||
await redis.select(205)
|
||||
memo_key = f"memo:{project_id}"
|
||||
|
||||
# 保存备忘录内容和创建时间
|
||||
memo_data = {
|
||||
"content": memo.content,
|
||||
"create_time": memo.create_time.isoformat()
|
||||
}
|
||||
|
||||
await redis.set(memo_key, json.dumps(memo_data))
|
||||
return {"message": "备忘录保存成功"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving memo: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/projects/{project_id}/memo")
|
||||
async def get_project_memo(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取项目备忘录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db205
|
||||
await redis.select(205)
|
||||
memo_key = f"memo:{project_id}"
|
||||
|
||||
# 获取备忘录
|
||||
memo_data = await redis.get(memo_key)
|
||||
|
||||
if not memo_data:
|
||||
return {"content": "", "create_time": None}
|
||||
|
||||
return json.loads(memo_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting memo: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.post("/experiments/{experiment_id}/memo")
|
||||
async def save_experiment_memo(
|
||||
experiment_id: str,
|
||||
memo: MemoModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""保存实验备忘录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db206
|
||||
await redis.select(206)
|
||||
memo_key = f"memo:{experiment_id}"
|
||||
|
||||
# 保存备忘录内容和创建时间
|
||||
memo_data = {
|
||||
"content": memo.content,
|
||||
"create_time": memo.create_time.isoformat()
|
||||
}
|
||||
|
||||
await redis.set(memo_key, json.dumps(memo_data))
|
||||
return {"message": "Experiment memo saved successfully"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving experiment memo: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/memo")
|
||||
async def get_experiment_memo(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取实验备忘录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db206
|
||||
await redis.select(206)
|
||||
memo_key = f"memo:{experiment_id}"
|
||||
|
||||
# 获取备忘录
|
||||
memo_data = await redis.get(memo_key)
|
||||
|
||||
if not memo_data:
|
||||
return {"content": "", "create_time": None}
|
||||
|
||||
return json.loads(memo_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting experiment memo: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -0,0 +1,410 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from typing import List
|
||||
from bson import ObjectId
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
from ..cores.db import get_database, get_redis
|
||||
from ..cores.config import UPLOAD_PATH, TaskStatus
|
||||
from ..routers.login import get_current_user, UserModel
|
||||
from ..models.paper import (
|
||||
ReferenceModel, QuestionModel,
|
||||
process_batch_analysis, process_reference_question,
|
||||
run_analysis_in_thread
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["paper"]
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/references/{reference_id}")
|
||||
async def delete_reference(
|
||||
project_id: str,
|
||||
reference_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""删除项目文献"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 验证项目所有权
|
||||
project = await db.projects.find_one({
|
||||
"_id": ObjectId(project_id),
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=403, detail="Unauthorized to access this project")
|
||||
|
||||
# 获取文献信息
|
||||
reference = await db.references.find_one({"_id": ObjectId(reference_id)})
|
||||
if not reference:
|
||||
raise HTTPException(status_code=404, detail="Reference not found")
|
||||
|
||||
# 删除文件
|
||||
if os.path.exists(reference["reference_link"]):
|
||||
os.remove(reference["reference_link"])
|
||||
|
||||
# 删除数据库记录
|
||||
await db.references.delete_one({"_id": ObjectId(reference_id)})
|
||||
|
||||
return {"message": "Reference successfully deleted"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting reference: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete reference: {str(e)}")
|
||||
|
||||
@router.get("/projects/{project_id}/references")
|
||||
async def get_project_references(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取项目的文献列表"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
# 查找项目下的所有文献
|
||||
references_cursor = db.references.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
references = []
|
||||
async for ref in references_cursor:
|
||||
references.append({
|
||||
"_id": str(ref["_id"]),
|
||||
"project_id": str(ref["project_id"]),
|
||||
"reference_link": ref["reference_link"],
|
||||
"reference_title": ref["reference_title"],
|
||||
"upload_time": ref["upload_time"]
|
||||
})
|
||||
return references
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting project references: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/references/{reference_id}/report")
|
||||
async def get_reference_report(
|
||||
reference_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis db203 读取已保存的文献报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{reference_id}"
|
||||
|
||||
existing_report = await redis.get(report_key)
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved reference report found")
|
||||
|
||||
return json.loads(existing_report)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting reference report: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{project_id}/analyze_report")
|
||||
async def analyze_reference_summary_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""分析文献数据"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
current_status = await redis.get(status_key)
|
||||
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
if status_data.get("status") == "processing":
|
||||
return {
|
||||
"message": "文献分析任务正在进行中",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data.get("start_time")
|
||||
}
|
||||
|
||||
reference_cursor = db.references.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
reference_ids = []
|
||||
async for reference in reference_cursor:
|
||||
reference_ids.append(str(reference["_id"]))
|
||||
|
||||
if not reference_ids:
|
||||
raise HTTPException(status_code=404, detail="No references found for this project")
|
||||
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat(),
|
||||
"total_references": len(reference_ids),
|
||||
"completed_references": 0
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
asyncio.create_task(run_analysis_in_thread(project_id, reference_ids))
|
||||
|
||||
return {
|
||||
"message": "文献分析任务已启动",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data["start_time"],
|
||||
"total_references": len(reference_ids)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting reference analysis: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{project_id}/analysis_status")
|
||||
async def get_reference_analysis_status(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取文献分析任务的状态"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
return {
|
||||
"status": "not_started",
|
||||
"project_id": project_id
|
||||
}
|
||||
|
||||
return json.loads(status_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting analysis status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{project_id}/summary_report")
|
||||
async def get_reference_summary_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis db204 读取已保存的文献报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
report_key = f"reference_summary_report:{project_id}"
|
||||
|
||||
existing_report = await redis.get(report_key)
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved reference report found")
|
||||
|
||||
return json.loads(existing_report)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting reference report: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.post("/references/{reference_id}/qa")
|
||||
async def ask_reference_question(
|
||||
reference_id: str,
|
||||
question: QuestionModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""向文献提问(异步)"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
reference = await db.references.find_one({"_id": ObjectId(reference_id)})
|
||||
if not reference:
|
||||
raise HTTPException(status_code=404, detail="文献不存在")
|
||||
|
||||
task_id = str(ObjectId())
|
||||
|
||||
asyncio.create_task(process_reference_question(
|
||||
task_id=task_id,
|
||||
reference_id=reference_id,
|
||||
question=question.question,
|
||||
reference=reference
|
||||
))
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.PENDING,
|
||||
"message": "问题已提交,正在处理中"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in ask_reference_question: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/task/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
"""获取任务状态和结果"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(208)
|
||||
task_data = await redis.hgetall(f"task:{task_id}")
|
||||
|
||||
if not task_data:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
|
||||
response = {
|
||||
"task_id": task_id,
|
||||
"status": task_data.get("status", TaskStatus.PENDING)
|
||||
}
|
||||
|
||||
if task_data.get("status") == TaskStatus.COMPLETED:
|
||||
response.update({
|
||||
"answer": task_data.get("answer"),
|
||||
"reference_title": task_data.get("reference_title")
|
||||
})
|
||||
elif task_data.get("status") == TaskStatus.FAILED:
|
||||
response.update({
|
||||
"error": task_data.get("error")
|
||||
})
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting task status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{reference_id}/qa/history")
|
||||
async def get_reference_qa_history(
|
||||
reference_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取文献问答历史记录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(207)
|
||||
chat_history_key = f"chat_history:{reference_id}"
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
if not history_data:
|
||||
return []
|
||||
|
||||
return json.loads(history_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting QA history: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.post("/projects/{project_id}/references/batch")
|
||||
async def batch_upload_references(
|
||||
project_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""批量上传项目相关文献"""
|
||||
db = await get_database()
|
||||
|
||||
try:
|
||||
project = await db.projects.find_one({
|
||||
"_id": ObjectId(project_id),
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
|
||||
|
||||
uploaded_references = []
|
||||
|
||||
for file in files:
|
||||
allowed_types = ["application/pdf", "application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"]
|
||||
if file.content_type not in allowed_types:
|
||||
continue
|
||||
|
||||
os.makedirs(UPLOAD_PATH, exist_ok=True)
|
||||
|
||||
file_extension = os.path.splitext(file.filename)[1]
|
||||
safe_filename = f"{project_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}{file_extension}"
|
||||
file_path = os.path.join(UPLOAD_PATH, safe_filename)
|
||||
|
||||
with open(file_path, "wb") as buffer:
|
||||
content = await file.read()
|
||||
buffer.write(content)
|
||||
|
||||
reference = {
|
||||
"project_id": ObjectId(project_id),
|
||||
"reference_link": file_path,
|
||||
"reference_title": file.filename,
|
||||
"upload_time": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
result = await db.references.insert_one(reference)
|
||||
reference_info = {
|
||||
"reference_id": str(result.inserted_id),
|
||||
"file_path": file_path,
|
||||
"reference_title": reference["reference_title"]
|
||||
}
|
||||
uploaded_references.append(reference_info)
|
||||
|
||||
redis = await get_redis()
|
||||
try:
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{str(result.inserted_id)}"
|
||||
initial_status = {
|
||||
"status": "processing",
|
||||
"message": "Analysis in progress"
|
||||
}
|
||||
await redis.set(report_key, json.dumps(initial_status))
|
||||
finally:
|
||||
await redis.aclose()
|
||||
|
||||
if uploaded_references:
|
||||
asyncio.create_task(process_batch_analysis(uploaded_references))
|
||||
|
||||
return {
|
||||
"message": f"Successfully uploaded {len(uploaded_references)} files",
|
||||
"uploaded_files": uploaded_references
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Batch upload error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,133 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from bson import ObjectId
|
||||
|
||||
from ..cores.db import PyObjectId, get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
# Router
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["projects"]
|
||||
)
|
||||
|
||||
# Pydantic models
|
||||
class ProjectModel(BaseModel):
|
||||
"""
|
||||
项目模型
|
||||
包含项目信息:用户ID、项目名称、创建时间、描述
|
||||
"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
user_id: Optional[PyObjectId] = None
|
||||
project_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# Project routes
|
||||
@router.post("/projects")
|
||||
async def create_project(
|
||||
project: ProjectModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建新项目
|
||||
|
||||
参数:
|
||||
project: 项目信息
|
||||
current_user: 当前登录用户
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
project_dict = {
|
||||
"user_id": current_user.id,
|
||||
"project_name": project.project_name,
|
||||
"description": project.description,
|
||||
"create_time": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
result = await db.projects.insert_one(project_dict)
|
||||
return {
|
||||
"message": "Project created successfully",
|
||||
"id": str(result.inserted_id)
|
||||
}
|
||||
|
||||
@router.get("/projects")
|
||||
async def get_projects(current_user: UserModel = Depends(get_current_user)):
|
||||
"""
|
||||
获取当前用户所有项目
|
||||
|
||||
参数:
|
||||
current_user: 当前登录用户,通过JWT token验证获取
|
||||
|
||||
返回:
|
||||
projects: 项目列表,每个项目包含完整的项目信息
|
||||
"""
|
||||
db = await get_database()
|
||||
|
||||
# 初始化项目列表
|
||||
projects = []
|
||||
# 异步查询数据库,获取该用户的所有项目
|
||||
async for project in db.projects.find({"user_id": current_user.id}):
|
||||
projects.append(ProjectModel(**project))
|
||||
|
||||
return projects
|
||||
# 添加删除项目的路由
|
||||
@router.delete("/lab/projects/{project_id}")
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""删除项目及其所有相关数据"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 验证项目是否存在且属于当前用户
|
||||
project = await db.projects.find_one({
|
||||
"_id": ObjectId(project_id),
|
||||
"user_id": current_user.id
|
||||
})
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
|
||||
|
||||
# 获取项目下所有实验的ID
|
||||
experiment_ids = []
|
||||
async for exp in db.experiments.find({"project_id": ObjectId(project_id)}):
|
||||
experiment_ids.append(str(exp["_id"]))
|
||||
|
||||
# 删除所有相关数据
|
||||
for exp_id in experiment_ids:
|
||||
await db.experiment_sessions.delete_many({"experiment_id": ObjectId(exp_id)})
|
||||
await db.experiment_devices.delete_many({"experiment_id": ObjectId(exp_id)})
|
||||
|
||||
# 删除Redis中的实验报告
|
||||
await redis.select(201)
|
||||
await redis.delete(f"experiment_report:{exp_id}")
|
||||
|
||||
# 删除所有实验
|
||||
await db.experiments.delete_many({"project_id": ObjectId(project_id)})
|
||||
|
||||
# 删除项目
|
||||
await db.projects.delete_one({"_id": ObjectId(project_id)})
|
||||
|
||||
# 删除Redis中的项目报告
|
||||
await redis.select(202)
|
||||
await redis.delete(f"project_report:{project_id}")
|
||||
|
||||
return {"message": "Project successfully deleted"}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting project: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete project: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -0,0 +1,558 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import asyncio
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from openai import OpenAI
|
||||
from redis import asyncio as aioredis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..cores.config import MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG, TaskStatus
|
||||
from ..cores.db import get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["project_report"]
|
||||
)
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
async def analyze_project_data(project_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing lab reports.
|
||||
You will summarize and analyze all lab reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""The project report is summarized based on the following experimental analysis reports:
|
||||
Experimental analysis reports: {json.dumps(project_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Project Analysis Report": {{
|
||||
"1. Project Overview": {{
|
||||
"Project Name": "{project_data['project_stats']['project_name']}",
|
||||
"Total Experiments": "{project_data['project_stats']['total_experiments']}",
|
||||
"Total Data Points": "{project_data['project_stats']['total_data_points']}",
|
||||
}},
|
||||
"2. Aggregated Statistics": {{
|
||||
"Total Sessions": {project_data['project_stats']['total_sessions']},
|
||||
"Total Duration": {project_data['project_stats']['total_duration']},
|
||||
"Total Data Points": {project_data['project_stats']['total_data_points']},
|
||||
"Average Session Duration": {project_data['project_stats']['avg_session_duration']},
|
||||
"Average Data Points per Session": {project_data['project_stats']['avg_data_points_per_session']}
|
||||
}},
|
||||
"3. Performance Analysis": {{
|
||||
"Best Performing Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Duration": "duration",
|
||||
"Data Points": "data_points",
|
||||
"Success Factors": "success_factors"
|
||||
}}
|
||||
],
|
||||
"Problematic Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Issues": "specific_issues",
|
||||
"Possible Causes": "possible_causes"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
"4. Common Findings": {{
|
||||
"Recurring Issues": "common_failure_modes",
|
||||
"Equipment performance": "equipment_performance",
|
||||
"Sensor reliability analysis": "sensor_reliability_analysis"
|
||||
}},
|
||||
"5. Recommendations": {{
|
||||
"Equipment optimization suggestions": "equipment_optimization_suggestions",
|
||||
"Experiment process improvement suggestions": "experiment_process_improvement_suggestions",
|
||||
"Data collection strategy adjustment": "data_collection_strategy_adjustment"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
@router.get("/projects/{project_id}/analyze")
|
||||
async def analyze_project_data_endpoint(
|
||||
project_id: str,
|
||||
force: bool = False,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""启动项目数据分析"""
|
||||
print(f"\n=== analyze_project_data_endpoint 开始 ===")
|
||||
print(f"项目ID: {project_id}")
|
||||
print(f"Force: {force}")
|
||||
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 检查是否已经有正在进行的分析任务
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
current_status = await redis.get(status_key)
|
||||
|
||||
if current_status and not force:
|
||||
status_data = json.loads(current_status)
|
||||
if status_data.get("status") == "processing":
|
||||
print("已有分析任务正在进行中")
|
||||
return {
|
||||
"message": "项目分析任务正在进行中",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data.get("start_time")
|
||||
}
|
||||
|
||||
# 记录分析开始状态
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
print("启动分析线程...")
|
||||
# 创建并启动新线程
|
||||
import threading
|
||||
analysis_thread = threading.Thread(
|
||||
target=run_project_in_thread,
|
||||
args=(project_id,),
|
||||
daemon=True
|
||||
)
|
||||
analysis_thread.start()
|
||||
print("分析线程已启动")
|
||||
|
||||
return {
|
||||
"message": "项目分析任务已启动",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data["start_time"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting project analysis: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
def run_project_in_thread(project_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_project_analysis(project_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_project_analysis(project_id: str):
|
||||
"""后台处理项目分析任务"""
|
||||
print(f"\n=== 开始项目分析 ===")
|
||||
print(f"项目ID: {project_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取项目名称
|
||||
project = await db.projects.find_one({"_id": ObjectId(project_id)})
|
||||
project_name = project["project_name"] if project else "Unknown Project"
|
||||
|
||||
# 从MongoDB查找项目下的所有实验
|
||||
experiment_cursor = db.experiments.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
experiment_ids = []
|
||||
experiment_names = {} # 存储实验ID和名称的映射
|
||||
async for experiment in experiment_cursor:
|
||||
exp_id = str(experiment["_id"])
|
||||
experiment_ids.append(exp_id)
|
||||
experiment_names[exp_id] = experiment["experiment_name"]
|
||||
|
||||
if not experiment_ids:
|
||||
print("错误: 未找到任何实验")
|
||||
raise Exception("No experiments found in this project")
|
||||
|
||||
# 从Redis db201获取所有实验报告
|
||||
await redis.select(201)
|
||||
all_experiment_reports = {}
|
||||
|
||||
# 统计数据初始化
|
||||
total_sessions = 0
|
||||
total_duration = 0
|
||||
total_data_points = 0
|
||||
total_devices = set()
|
||||
total_sensors = 0
|
||||
|
||||
for exp_id in experiment_ids:
|
||||
report_key = f"experiment_report:{exp_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_experiment_reports[exp_id] = parsed_report
|
||||
|
||||
# 从Basic Information中提取数据
|
||||
basic_info = parsed_report["Experiment Analysis Report"]["1. Basic Information"]
|
||||
total_sessions += int(basic_info["Total Sessions"])
|
||||
total_duration += float(basic_info["Total Duration"].split()[0]) # 去掉"seconds"
|
||||
total_data_points += int(basic_info["Data Points"])
|
||||
total_devices.update([str(i) for i in range(int(basic_info["Devices number"]))])
|
||||
total_sensors += int(basic_info["Sensors number"])
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
continue
|
||||
|
||||
if not all_experiment_reports:
|
||||
raise Exception("No experiment reports found")
|
||||
|
||||
# 计算平均值
|
||||
avg_session_duration = total_duration / total_sessions if total_sessions > 0 else 0
|
||||
avg_data_points_per_session = total_data_points / total_sessions if total_sessions > 0 else 0
|
||||
|
||||
# 添加项目级统计数据
|
||||
project_stats = {
|
||||
"project_name": project_name,
|
||||
"total_experiments": len(experiment_ids),
|
||||
"total_sessions": total_sessions,
|
||||
"total_duration": total_duration,
|
||||
"total_data_points": total_data_points,
|
||||
"total_devices": len(total_devices),
|
||||
"total_sensors": total_sensors,
|
||||
"avg_session_duration": avg_session_duration,
|
||||
"avg_data_points_per_session": avg_data_points_per_session,
|
||||
"experiment_names": experiment_names
|
||||
}
|
||||
|
||||
# 将统计数据和原始报告一起发送给分析函数
|
||||
project_data = {
|
||||
"project_stats": project_stats,
|
||||
"experiment_reports": all_experiment_reports
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_project_data(project_data)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态为完成
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("\n7. 分析失败")
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/projects/{project_id}/analysis_status")
|
||||
async def get_project_analysis_status(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取项目分析任务的状态"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
return {
|
||||
"status": "not_started",
|
||||
"project_id": project_id
|
||||
}
|
||||
|
||||
return json.loads(status_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting analysis status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/projects/{project_id}/report")
|
||||
async def get_project_saved_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis db202 读取已保存的项目报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db202
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
|
||||
# 获取已保存的报告
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved project report found")
|
||||
|
||||
# 返回报告
|
||||
return json.loads(existing_report)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting project report: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
"""问题模型"""
|
||||
question: str
|
||||
|
||||
# 项目问答
|
||||
|
||||
@router.post("/lab/projects/{project_id}/qa")
|
||||
async def ask_project_question(
|
||||
project_id: str,
|
||||
question: QuestionModel,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""向项目报告提问(异步)"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 验证项目是否存在
|
||||
project = await db.projects.find_one({"_id": ObjectId(project_id)})
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="项目不存在")
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(ObjectId())
|
||||
|
||||
# 创建后台任务
|
||||
asyncio.create_task(process_project_question(
|
||||
task_id=task_id,
|
||||
project_id=project_id,
|
||||
question=question.question,
|
||||
project=project
|
||||
))
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.PENDING,
|
||||
"message": "问题已提交,正在处理中"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in ask_project_question: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def process_project_question(task_id: str, project_id: str, question: str, project: dict):
|
||||
"""处理项目问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"project_id": project_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("项目分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的项目助手,负责回答关于项目报告的问题。
|
||||
你应该基于项目报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"项目名称": project.get("project_name", "未知项目"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下项目报告回答问题:\n\n项目信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"project_chat_history:{project_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"project_name": project.get("project_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/lab/projects/{project_id}/qa/history")
|
||||
async def get_project_qa_history(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取项目问答历史记录"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 从Redis db207获取对话历史
|
||||
await redis.select(207)
|
||||
chat_history_key = f"project_chat_history:{project_id}"
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
if not history_data:
|
||||
return []
|
||||
|
||||
return json.loads(history_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting QA history: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -0,0 +1,142 @@
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Dict, Set
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Connection Manager for WebSocket clients
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, Dict[str, Set[WebSocket]]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
# 添加消息队列
|
||||
self.message_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.broadcast_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, experiment_id: str, serial_number: str):
|
||||
"""添加新的WebSocket连接"""
|
||||
async with self._lock:
|
||||
if experiment_id not in self.active_connections:
|
||||
self.active_connections[experiment_id] = {}
|
||||
if serial_number not in self.active_connections[experiment_id]:
|
||||
self.active_connections[experiment_id][serial_number] = set()
|
||||
self.active_connections[experiment_id][serial_number].add(websocket)
|
||||
print(f"新连接已添加到管理器 - 实验ID: {experiment_id}, 设备: {serial_number}")
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, experiment_id: str, serial_number: str):
|
||||
"""移除WebSocket连接"""
|
||||
async with self._lock:
|
||||
try:
|
||||
if (experiment_id in self.active_connections and
|
||||
serial_number in self.active_connections[experiment_id]):
|
||||
self.active_connections[experiment_id][serial_number].remove(websocket)
|
||||
# 只在成功移除连接时打印一次日志
|
||||
print(f"WebSocket连接已断开 - 实验ID: {experiment_id}, 设备: {serial_number}")
|
||||
|
||||
# 清理空集合
|
||||
if not self.active_connections[experiment_id][serial_number]:
|
||||
del self.active_connections[experiment_id][serial_number]
|
||||
if not self.active_connections[experiment_id]:
|
||||
del self.active_connections[experiment_id]
|
||||
except KeyError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"断开连接时出错: {e}")
|
||||
|
||||
async def broadcast_worker(self, experiment_id: str, serial_number: str):
|
||||
"""单独的广播工作器"""
|
||||
queue = self.message_queues[f"{experiment_id}:{serial_number}"]
|
||||
while True:
|
||||
try:
|
||||
message = await queue.get()
|
||||
disconnected = set()
|
||||
|
||||
async with self._lock:
|
||||
connections = self.active_connections[experiment_id][serial_number].copy()
|
||||
|
||||
await asyncio.gather(*[
|
||||
self.send_message(ws, message, disconnected)
|
||||
for ws in connections
|
||||
], return_exceptions=True)
|
||||
|
||||
queue.task_done()
|
||||
|
||||
# 清理断开的连接
|
||||
for websocket in disconnected:
|
||||
await self.disconnect(websocket, experiment_id, serial_number)
|
||||
except Exception as e:
|
||||
print(f"广播工作器错误: {e}")
|
||||
|
||||
async def broadcast(self, message: str, experiment_id: str, serial_number: str):
|
||||
"""使用消息队列进行广播"""
|
||||
queue_key = f"{experiment_id}:{serial_number}"
|
||||
if queue_key not in self.message_queues:
|
||||
self.message_queues[queue_key] = asyncio.Queue()
|
||||
self.broadcast_tasks[queue_key] = asyncio.create_task(
|
||||
self.broadcast_worker(experiment_id, serial_number)
|
||||
)
|
||||
|
||||
await self.message_queues[queue_key].put(message)
|
||||
|
||||
async def send_message(self, websocket: WebSocket, message: str, disconnected: set):
|
||||
"""处理单个连接的消息发送"""
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
disconnected.add(websocket)
|
||||
|
||||
# 创建连接管理器实例
|
||||
manager = ConnectionManager()
|
||||
|
||||
@router.websocket("/lab/ws/{experiment_id}/{serial_number}")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
experiment_id: str,
|
||||
serial_number: str
|
||||
):
|
||||
try:
|
||||
await websocket.accept()
|
||||
print(f"WebSocket连接已建立 - 实验ID: {experiment_id}, 设备: {serial_number}")
|
||||
|
||||
# 添加到连接管理器
|
||||
await manager.connect(websocket, experiment_id, serial_number)
|
||||
|
||||
# 发送初始连接响应
|
||||
response = {
|
||||
"type": "connect_response",
|
||||
"status": "success",
|
||||
"message": "连接成功",
|
||||
"experiment_id": experiment_id,
|
||||
"serial_number": serial_number
|
||||
}
|
||||
await websocket.send_text(json.dumps(response))
|
||||
|
||||
try:
|
||||
# 保持连接并持续接收消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
# 只处理非心跳消息
|
||||
if data != "ping":
|
||||
message = json.loads(data)
|
||||
await manager.broadcast(data, experiment_id, serial_number)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue # 忽略无效的JSON数据
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理WebSocket消息时出错: {e}")
|
||||
|
||||
finally:
|
||||
await manager.disconnect(websocket, experiment_id, serial_number)
|
||||
|
||||
@router.get("/lab/status")
|
||||
async def get_status():
|
||||
return {
|
||||
"status": "running",
|
||||
"timestamp": time.time(),
|
||||
"active_connections": len(manager.active_connections)
|
||||
}
|
||||
Reference in New Issue
Block a user