This commit is contained in:
2025-01-21 10:40:29 +00:00
parent 3204ec5ccb
commit e0b92b5f32
20 changed files with 3899 additions and 0 deletions
+74
View File
@@ -0,0 +1,74 @@
from typing import Dict, Any
import os
from datetime import datetime, timezone
# MongoDB 配置
MONGODB_URL = "mongodb://localhost:27017"
DATABASE_NAME = "lab"
# Redis 配置
REDIS_URL = "redis://localhost:6379"
# DeepSeek API 配置
DEEPSEEK_API_CONFIG = {
"base_url": "https://api.deepseek.com/v1",
"api_key": "sk-3027fb3c810b4e17985fa397d41250b9"
}
# 文件上传配置
UPLOAD_PATH = "/obscura/task/references"
# Redis 数据库映射
REDIS_DB_MAP = {
"experiment_analysis": 201, # 实验分析报告
"project_analysis": 202, # 项目分析报告
"reference_analysis": 203, # 文献分析报告
"reference_summary": 204, # 文献汇总分析
"project_memo": 205, # 项目备忘录
"experiment_memo": 206, # 实验备忘录
"chat_history": 207, # 聊天历史记录
"task_status": 208 # 任务状态
}
# 任务状态枚举
class TaskStatus:
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
# 线程池配置
THREAD_POOL_CONFIG = {
"pdf_workers": 3,
"analysis_workers": 3,
"pro_analysis_workers": 3
}
# API 响应格式
def format_response(data: Any, message: str = "Success") -> Dict[str, Any]:
"""格式化 API 响应"""
return {
"status": "success",
"message": message,
"data": data,
"timestamp": datetime.now(timezone.utc).isoformat()
}
def format_error(message: str, status_code: int = 500) -> Dict[str, Any]:
"""格式化错误响应"""
return {
"status": "error",
"message": message,
"error_code": status_code,
"timestamp": datetime.now(timezone.utc).isoformat()
}
# 文件类型配置
ALLOWED_FILE_TYPES = {
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
# 确保上传目录存在
os.makedirs(UPLOAD_PATH, exist_ok=True)
+53
View File
@@ -0,0 +1,53 @@
from bson import ObjectId
from pydantic import BaseModel, Field
from motor.motor_asyncio import AsyncIOMotorClient
from contextlib import asynccontextmanager
from redis import asyncio as aioredis
# Database Configuration
MONGODB_URL = "mongodb://lab:y6aHwySAhzrbibLD@222.186.10.253:27017/lab"
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
class PyObjectId(ObjectId):
"""
自定义ObjectId类,用于在Pydantic模型中处理MongoDB的ObjectId
"""
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v, handler):
if not ObjectId.is_valid(v):
raise ValueError("Invalid ObjectId")
return ObjectId(v)
@classmethod
def __get_pydantic_json_schema__(cls, _schema_cache, **_kwargs):
return {
'type': 'string',
'description': 'ObjectId',
'pattern': r'^[0-9a-fA-F]{24}$'
}
class Database:
"""数据库连接管理类"""
client: AsyncIOMotorClient = None
async def get_database() -> AsyncIOMotorClient:
return Database.client.lab
async def connect_to_mongo():
Database.client = AsyncIOMotorClient(MONGODB_URL)
async def close_mongo_connection():
if Database.client:
Database.client.close()
async def get_redis():
redis = aioredis.from_url(
REDIS_URL,
encoding="utf-8",
decode_responses=True
)
return redis
+72
View File
@@ -0,0 +1,72 @@
# Standard library imports
from contextlib import asynccontextmanager
# Third-party imports
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
# Local application imports
from .cores.config import DEEPSEEK_API_CONFIG
from .cores.db import (
connect_to_mongo,
close_mongo_connection,
)
from .routers.login import router as login_router
from .routers.project import router as project_router
from .routers.device import router as device_router
from .routers.experiment_device import router as experiment_device_router
from .routers.experiment import router as experiment_router
from .routers.experiment_report import router as experiment_report_router
from .routers.websocket import router as websocket_router
from .routers.project_report import router as project_report_router
from .routers.memo import router as memo_router
from .routers.paper import router as paper_router
# DeepSeek API Configuration
client = OpenAI(
base_url=DEEPSEEK_API_CONFIG["base_url"],
api_key=DEEPSEEK_API_CONFIG["api_key"]
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
await connect_to_mongo()
yield
# Shutdown
await close_mongo_connection()
# 更新 FastAPI 实例化
app = FastAPI(lifespan=lifespan)
# 添加路由
app.include_router(login_router)
app.include_router(project_router)
app.include_router(device_router)
app.include_router(experiment_device_router)
app.include_router(experiment_router)
app.include_router(experiment_report_router)
app.include_router(websocket_router)
app.include_router(project_report_router)
app.include_router(memo_router)
app.include_router(paper_router)
# CORS configuration
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"]
)
# Health check endpoint
@app.get("/health")
async def health_check():
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=6000)
+614
View File
@@ -0,0 +1,614 @@
import json
import os
import asyncio
from datetime import datetime, timezone
from typing import List, Dict
import PyPDF2
import aiohttp
from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI
from ..cores.config import DEEPSEEK_API_CONFIG
from ..cores.db import get_redis
from enum import Enum
# 创建OpenAI客户端
client = OpenAI(
base_url=DEEPSEEK_API_CONFIG["base_url"],
api_key=DEEPSEEK_API_CONFIG["api_key"]
)
# 修改函数定义为异步函数
async def analyze_reference_summary(reference_data):
system_prompt = """
You are an AI assistant responsible for analyzing paper reports.
You will summarize and analyze all paper reports and generate a comprehensive analysis report in JSON format.
The JSON structure must strictly follow the provided template.
"""
user_prompt = f"""Analyze the following paper reports:
Paper reports: {json.dumps(reference_data, ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"Paper Summary Report": {{
"overview": {{
"total_papers": "[Number of papers]",
"time_range": {{"start_year": "[Start year]", "end_year": "[End year]"}},
"main_research_areas": "[Main research areas of the papers]"
}},
"research_trends": {{
"major_themes": "[Major themes areas of the papers]",
"common_methodologies": "[Common methodologies used in the papers]",
"emerging_topics": "[Emerging topics in the papers]"
}},
"key_findings": {{
"theoretical_advances": "[Theoretical advances in the papers]",
"experimental_results": "[Experimental results in the papers]",
"common_conclusions": "[Common conclusions in the papers]"
}},
"research_gaps": {{
"current_limitations": "[Current limitations]",
"unexplored_areas": "[Unexplored areas]",
"technical_challenges": "[Technical challenges"
}},
"future_directions": {{
"potential_applications": "[Potential future applications]",
"methodological_suggestions": "[Methodological suggestions based on paper summary]"
}},
"impact_assessment": {{
"academic_influence": "[Summarize the academic influence of the paper]",
"practical_value": "[Summarize the practical value of the paper]"
}}
}}
}}
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
try:
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages,
response_format={'type': 'json_object'}
)
return json.loads(response.choices[0].message.content)
except Exception as e:
print(f"Error calling DeepSeek API: {e}")
return None
def run_analysis_in_thread(project_id: str, reference_ids: List[str]):
"""在独立线程中运行分析任务"""
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在新的事件循环中运行异步任务
loop.run_until_complete(process_reference_analysis(project_id, reference_ids))
except Exception as e:
print(f"Error in analysis thread: {str(e)}")
finally:
loop.close()
async def process_reference_analysis(project_id: str, reference_ids: List[str]):
"""后台处理文献分析任务"""
redis = await get_redis()
try:
# 更新状态键
await redis.select(204)
status_key = f"reference_analysis_status:{project_id}"
# 从Redis db203获取所有实验报告
await redis.select(203)
all_reference_reports = {}
completed_count = 0
for ref_id in reference_ids:
report_key = f"reference_report:{ref_id}"
report_data = await redis.get(report_key)
if report_data:
try:
parsed_report = json.loads(report_data)
all_reference_reports[ref_id] = parsed_report
completed_count += 1
# 更新进度
await redis.select(204)
current_status = await redis.get(status_key)
if current_status:
status_data = json.loads(current_status)
status_data["completed_references"] = completed_count
await redis.set(status_key, json.dumps(status_data))
await redis.select(203)
except json.JSONDecodeError:
print(f"Error parsing reference {ref_id} report")
continue
if all_reference_reports:
analysis_result = await analyze_reference_summary(all_reference_reports)
if analysis_result:
await redis.select(204)
report_key = f"reference_summary_report:{project_id}"
await redis.set(report_key, json.dumps(analysis_result))
# 更新最终状态
status_data = {
"status": "completed",
"completion_time": datetime.now(timezone.utc).isoformat(),
"total_references": len(reference_ids),
"completed_references": completed_count
}
await redis.set(status_key, json.dumps(status_data))
print(f"Reference analysis completed for project {project_id}")
else:
# 更新失败状态
status_data = {
"status": "failed",
"error": "Failed to generate analysis result",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
print(f"Failed to generate analysis for project {project_id}")
else:
# 更新失败状态
status_data = {
"status": "failed",
"error": "No valid reference reports found",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
print(f"No valid reference reports found for project {project_id}")
except Exception as e:
print(f"Error in background reference analysis: {str(e)}")
try:
await redis.select(204)
status_data = {
"status": "failed",
"error": str(e),
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
except Exception as redis_error:
print(f"Error updating Redis status: {redis_error}")
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def process_reference_question(task_id: str, reference_id: str, question: str, reference: dict):
"""处理文献问答的后台任务"""
redis = await get_redis()
try:
# 更新任务状态为处理中
await redis.select(208) # 使用db208存储任务状态
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.PROCESSING,
"reference_id": reference_id,
"question": question
}
)
# 从Redis获取分析报告
await redis.select(203)
report_key = f"reference_report:{reference_id}"
report_data = await redis.get(report_key)
if not report_data:
raise Exception("文献分析报告不存在,请先进行分析")
# 读取PDF文件内容
file_path = reference.get("reference_link")
if not file_path or not os.path.exists(file_path):
raise Exception("文献文件不存在")
# 读取PDF内容
pdf_content = ""
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
content = page.extract_text()
pdf_content += content
# 限制内容长度
MAX_CHARS = 180000
pdf_content = pdf_content[:MAX_CHARS]
# 构建系统提示和用户提示
system_prompt = """
你是一个专业的学术助手,负责回答关于学术文献的问题。
你应该基于文献内容和分析报告提供准确、专业的回答。
回答应当简洁明了,并尽可能引用文献中的具体内容。
"""
# 构建上下文
context = {
"文献内容": pdf_content,
"分析报告": json.loads(report_data)
}
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"基于以下文献内容和分析报告回答问题:\n\n文献信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
]
# 调用DeepSeek API获取回答
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages
)
answer = response.choices[0].message.content
# 保存对话历史到Redis db207
await redis.select(207)
chat_history_key = f"chat_history:{reference_id}"
# 获取现有历史记录
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
# 添加新的对话
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 保存更新后的历史记录
await redis.set(chat_history_key, json.dumps(history))
# 更新任务状态为完成
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.COMPLETED,
"answer": answer,
"reference_title": reference.get("reference_title")
}
)
except Exception as e:
print(f"Error processing question: {e}")
# 更新任务状态为失败
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.FAILED,
"error": str(e)
}
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 在全局范围创建线程池
pdf_thread_pool = ThreadPoolExecutor(max_workers=3) # 限制并发PDF处理数量
# 创建异步HTTP客户端会话
async def get_aiohttp_session():
return aiohttp.ClientSession(
base_url="https://api.deepseek.com/v1/", # 添加了末尾的斜杠
headers={"Authorization": f"Bearer sk-3027fb3c810b4e17985fa397d41250b9"}
)
async def read_pdf_async(file_path: str) -> str:
"""在线程池中异步读取PDF"""
def read_pdf():
try:
pdf_content = ""
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
content = page.extract_text()
pdf_content += content
# 根据内容长度返回不同的结果
content_length = len(pdf_content)
print(f"\nPDF原始内容字符数: {content_length}")
if content_length <= 180000:
print("文档长度 ≤ 180000,返回完整内容")
return pdf_content
elif content_length <= 200000:
print("文档长度在 180000-200000 之间,截取前 180000 个字符")
return pdf_content[:180000]
else:
print("文档长度 > 200000,返回完整内容供分段处理")
return pdf_content # 返回完整内容,由调用者处理分段
except Exception as e:
print(f"Error reading PDF: {e}")
return ""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(pdf_thread_pool, read_pdf)
async def call_deepseek_api_async(messages: list) -> dict:
"""异步调用DeepSeek API"""
async with await get_aiohttp_session() as session:
async with session.post("/chat/completions", json={
"model": "deepseek-chat",
"messages": messages,
"response_format": {"type": "json_object"}
}) as response:
if response.status == 200:
data = await response.json()
return json.loads(data["choices"][0]["message"]["content"])
else:
raise Exception(f"API调用失败: {await response.text()}")
async def analyze_long_document_async(content: str) -> List[dict]:
"""分段分析长文档"""
# 将内容分成多个段落,每段约60000个字符(估算后约50000 tokens
segments = []
content_length = len(content)
segment_size = 60000 # 减小段落大小
print(f"\n开始分段处理,总字符数: {content_length}")
print(f"每段大小: {segment_size} 字符")
for i in range(0, content_length, segment_size):
segment = content[i:i + segment_size]
segments.append(segment)
print(f"文档已分段,共 {len(segments)} 个段落")
# 对每个段落进行分析
analysis_results = []
for i, segment in enumerate(segments):
print(f"\n开始分析第 {i+1}/{len(segments)} 个段落...")
system_prompt = f"""
You are an AI assistant tasked with analyzing part {i+1} of {len(segments)} of an academic paper.
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
Note that this is part {i+1} of a longer document, so focus on the content provided.
The JSON structure must strictly follow the provided template.
Also generate a Mermaid flowchart code to visualize the research methodology if this segment contains methodology information.
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
"""
user_prompt = f"""Analyze the following paper segment and extract all relevant information:
Content: {segment}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Author name(s) and affiliations]",
"publication_date": "[Publication date in YYYY-MM format]",
"title": "[Full title of the document]",
"journal_publisher": "[Journal name or publisher details]",
"document_type": "[Type: journal article/book/conference paper etc.]"
}},
"2. Content Analysis": {{
"abstract": "[Paper abstract]",
"research_purpose": "[Main objectives and research questions]",
"methodology": "[Research methods, data collection and analysis approaches]",
"main_arguments": "[Key theoretical frameworks and arguments]",
"conclusions": "[Complete findings and conclusions]",
"innovations": "[Novel contributions and original aspects]"
}},
"flowchart": "[If this segment contains methodology information, generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node reference
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If this segment does not contain methodology information, set this field to null.]"
}}
"""
try:
result = await call_deepseek_api_async([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
])
if result:
print(f"{i+1} 个段落分析完成")
analysis_results.append(result)
# 等待一小段时间避免API限制
await asyncio.sleep(1)
else:
print(f"{i+1} 个段落分析失败")
except Exception as e:
print(f"分析段落时出错: {str(e)}")
continue
return analysis_results
async def merge_analysis_results(results: List[dict]) -> dict:
"""合并多个分析结果"""
if not results:
return {}
print(f"开始合并 {len(results)} 个分析结果...")
system_prompt = """
You are an AI assistant tasked with merging multiple analysis results of different parts of the same academic paper.
Generate a comprehensive merged analysis in JSON format.
The JSON structure must strictly follow the provided template.
Ensure the merged result is coherent and eliminates redundancy.
For the flowchart, select the most comprehensive one from the input results, or combine multiple flowcharts if they contain complementary information.
When combining flowcharts, ensure the result follows Mermaid syntax rules and avoids special characters.
Create a detailed and comprehensive flowchart that accurately represents the paper's complete research methodology.
"""
user_prompt = f"""Merge the following analysis results into a single coherent analysis:
Analysis Results: {json.dumps(results, ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Merged author information]",
"publication_date": "[Publication date]",
"title": "[Complete title]",
"journal_publisher": "[Journal/publisher information]",
"document_type": "[Document type]"
}},
"2. Content Analysis": {{
"abstract": "[Complete abstract]",
"research_purpose": "[Comprehensive research objectives]",
"methodology": "[Complete methodology description]",
"main_arguments": "[Comprehensive theoretical frameworks and arguments]",
"conclusions": "[Complete findings and conclusions]",
"innovations": "[Complete list of innovations]"
}},
"flowchart": "[Select or combine the flowcharts following these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node reference
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If no valid flowchart is found in any segment, set this field to null.]"
}}
"""
result = await call_deepseek_api_async([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
])
if result:
print("分析结果合并完成")
else:
print("分析结果合并失败")
return result
async def analyze_reference_document_async(content: str):
"""分析文献的基本信息和内容"""
system_prompt = """
You are an AI assistant tasked with analyzing academic paper.
Generate a comprehensive analysis in JSON format covering both basic information and content analysis.
The JSON structure must strictly follow the provided template.
Also generate a Mermaid flowchart code to visualize the research methodology.
Ensure the flowchart follows strict formatting rules to avoid parsing errors.
Create a detailed and comprehensive flowchart that accurately represents the paper's research methodology.
"""
user_prompt = f"""Analyze the following paper and extract all relevant information:
Content: {content}
Generate a JSON response with the following structure:
{{
"1. Basic Information": {{
"author": "[Author name(s) and affiliations]",
"publication_date": "[Publication date in YYYY-MM format]",
"title": "[Full title of the document]",
"journal_publisher": "[Journal name or publisher details]",
"document_type": "[Type: journal article/book/conference paper etc.]"
}},
"2. Content Analysis": {{
"abstract": "[Paper abstract]",
"research_purpose": "[Main objectives and research questions]",
"methodology": "[Research methods, data collection and analysis approaches]",
"main_arguments": "[Key theoretical frameworks and arguments]",
"conclusions": "[Primary findings and conclusions]",
"innovations": "[Novel contributions and original aspects]"
}},
"flowchart": "[Generate a Mermaid flowchart code that visualizes the research methodology. Follow these rules:
1. Use 'graph TD' for top-down flow
2. Each node should be in format: id[text] where:
- id is a unique identifier (A, B1, B2, etc.)
- text should be simple and clear, using only letters, numbers, and spaces
- DO NOT use any special characters including parentheses, colons, commas
- abbreviations should be written without parentheses, e.g., 'DNN' not '(DNN)'
- use space instead of special characters, e.g., 'Deep Learning Model' not 'Deep-Learning/Model'
3. Connections use '-->' between nodes
4. Ensure each line ends with a proper node reference
Create a detailed flowchart that shows:
- Research objectives and questions
- All major research methods used
- Data collection and analysis processes
- Key experimental or analytical steps
- Result synthesis and conclusion formation
Make the flowchart as detailed as possible while maintaining clarity.
If the paper does not contain clear methodology information, set this field to null.]"
}}
"""
return await call_deepseek_api_async([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
])
async def analyze_reference_value_async(content_analysis: dict):
"""基于内容分析结果评估文献的价值"""
system_prompt = """
You are an AI assistant tasked with evaluating the value of academic paper based on its content analysis.
Generate a comprehensive value evaluation in JSON format.
The JSON structure must strictly follow the provided template.
"""
user_prompt = f"""Based on the following content analysis, evaluate the paper's value:
Content Analysis: {json.dumps(content_analysis, ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"3. Value Evaluation": {{
"academic_contribution": "[Significance to the field of study]",
"practical_significance": "[Real-world applications and implications]",
"limitations": "[Research constraints and weaknesses]",
"implications": "[Suggestions for future research and practice]"
}}
}}
"""
return await call_deepseek_api_async([
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
])
+205
View File
@@ -0,0 +1,205 @@
from fastapi import APIRouter, HTTPException, Depends
from typing import List
from datetime import datetime, timezone
from bson import ObjectId
from ..cores.db import PyObjectId, get_database
from .login import get_current_user, UserModel
router = APIRouter()
@router.post("/lab/user/devices/{serial_number}")
async def add_device_to_user(
serial_number: str,
current_user: UserModel = Depends(get_current_user)
):
"""
通过序列号添加设备到用户设备列表
参数:
serial_number: 设备序列号
current_user: 当前登录用户
返回:
设备信息和添加状态
"""
db = await get_database()
try:
# 查找具有该序列号的设备
device = await db.devices.find_one({
"serial_numbers": {
"$elemMatch": {
"serial": serial_number
}
}
})
if not device:
raise HTTPException(status_code=404, detail="Device not found or invalid serial number")
# 检查该序列号是否已被其他用户激活
active_device = await db.user_devices.find_one({
"serial_number": serial_number,
"status": "active"
})
if active_device:
raise HTTPException(status_code=400, detail="Device is already in use by another user")
# Check if this serial number has already been added
existing = await db.user_devices.find_one({
"user_id": current_user.id,
"serial_number": serial_number
})
if existing:
if existing["status"] == "active":
raise HTTPException(status_code=400, detail="This device is already active")
else:
# 如果设备存在但状态是inactive,则重新激活
await db.user_devices.update_one(
{"_id": existing["_id"]},
{"$set": {"status": "active"}}
)
return {
"message": "Device reactivated successfully",
"device": {
"user_device_id": str(existing["_id"]),
"device_id": str(existing["device_id"]),
"serial_number": serial_number,
"device_name": existing["device_name"],
"device_type": existing["device_type"],
"device_number": existing["device_number"],
"sensors": existing.get("sensors", []),
"status": "active"
}
}
# 创建新的关联记录,包含完整的设备信息、序列号和传感器信息
user_device_dict = {
"user_id": current_user.id,
"device_id": device["_id"],
"serial_number": serial_number,
"device_name": device["device_name"],
"device_type": device["device_type"],
"device_number": device["device_number"],
"sensors": device.get("sensors", []), # 添加传感器信息
"status": "active", # 添加状态字段
"add_time": datetime.now(timezone.utc) # 添加时间记录
}
# 只有当序列号状态为available时才更新为in_use
await db.devices.update_one(
{
"_id": device["_id"],
"serial_numbers": {
"$elemMatch": {
"serial": serial_number,
"status": "available" # 只匹配状态为available的序列号
}
}
},
{
"$set": {
"serial_numbers.$.status": "in_use"
}
}
)
result = await db.user_devices.insert_one(user_device_dict)
# 返回设备信息
return {
"message": "Device added successfully",
"device": {
"user_device_id": str(result.inserted_id),
"device_id": str(device["_id"]),
"serial_number": serial_number,
"device_name": device["device_name"],
"device_type": device["device_type"],
"device_number": device["device_number"],
"sensors": device.get("sensors", []), # 在响应中包含传感器信息
"status": "active" # 添加状态字段"status": "active" # 添加状态字段
}
}
except Exception as e:
print(f"Error adding device to user: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to add device: {str(e)}")
@router.get("/lab/userdevices")
async def get_user_devices(current_user: UserModel = Depends(get_current_user)):
"""
获取当前用户的所有传感器列表
返回:
List[Dict]: 传感器列表,包含当前用户关联的所有传感器信息
"""
db = await get_database()
try:
devices = []
async for device in db.user_devices.find({"user_id": current_user.id}):
devices.append({
"user_device_id": str(device["_id"]),
"device_id": str(device["device_id"]),
"serial_number": device["serial_number"], # 添加序列号
"device_name": device["device_name"],
"device_type": device["device_type"],
"device_number": device["device_number"],
"sensors": device.get("sensors", []), # 在响应中包含传感器信息
"status": device.get("status", "active")
})
return devices
except Exception as e:
print(f"Error getting user devices: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get user devices: {str(e)}")
@router.delete("/lab/user/devices/{device_id}")
async def remove_user_device(
device_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""
将设备标记为非活动状态
参数:
device_id: 用户设备的ID
current_user: 当前登录用户
"""
db = await get_database()
try:
# 验证设备是否存在且属于当前用户,且状态为active
device = await db.user_devices.find_one({
"_id": ObjectId(device_id),
"user_id": current_user.id,
"status": "active" # 只能停用活动状态的设备
})
if not device:
raise HTTPException(status_code=404, detail="Active device not found or unauthorized access")
# 更新设备状态为非活动
result = await db.user_devices.update_one(
{
"_id": ObjectId(device_id),
"user_id": current_user.id
},
{
"$set": {
"status": "inactive",
"deactivate_time": datetime.now(timezone.utc) # 记录停用时间
}
}
)
if result.modified_count == 0:
raise HTTPException(status_code=500, detail="Failed to deactivate device")
return {
"message": "Device deactivated successfully",
"device_id": device_id,
"status": "inactive", # 添加状态字段
"deactivate_time": datetime.now(timezone.utc).isoformat()
}
except Exception as e:
print(f"Error deactivating device: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to deactivate device: {str(e)}")
+632
View File
@@ -0,0 +1,632 @@
from fastapi import APIRouter, HTTPException, Depends, Query
from typing import List, Optional
from datetime import datetime, timezone
from bson import ObjectId
from pydantic import BaseModel, Field
import json
from ..cores.db import PyObjectId, get_database, get_redis
from .login import get_current_user, UserModel
from io import BytesIO, StringIO
import zipfile
import csv
from fastapi.responses import StreamingResponse
router = APIRouter()
class ExperimentStatus:
ACTIVE = "active"
COMPLETED = "completed"
class ExperimentModel(BaseModel):
id: Optional[PyObjectId] = Field(alias="_id", default=None)
project_id: PyObjectId
experiment_name: str
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
description: str
status: str = Field(default=ExperimentStatus.ACTIVE) # 添加状态字段
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
class ExperimentCreate(BaseModel):
"""实验创建请求模型"""
project_id: str
experiment_name: str
description: str | None = None # 修改为可选字段,默认为None
# 创建实验
@router.post("/lab/experiments")
async def create_experiment(
experiment: ExperimentCreate,
current_user: UserModel = Depends(get_current_user)
):
"""创建新实验"""
db = await get_database()
try:
# 验证项目是否存在且于当前用户
project = await db.projects.find_one({
"_id": ObjectId(experiment.project_id),
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
# 构造实验文档
experiment_dict = {
"project_id": ObjectId(experiment.project_id),
"experiment_name": experiment.experiment_name,
"description": experiment.description or "", # 如果None则使用空字符串
"create_time": datetime.now(timezone.utc),
"status": ExperimentStatus.ACTIVE
}
result = await db.experiments.insert_one(experiment_dict)
return {
"message": "Experiment created successfully",
"id": str(result.inserted_id)
}
except Exception as e:
print(f"Error creating experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to create experiment: {str(e)}")
# 获取实验列表
@router.get("/lab/experiments")
async def get_experiments(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取项目下的所有实验"""
db = await get_database()
# 验证项目是否存在且属于当前用户
project = await db.projects.find_one({
"_id": ObjectId(project_id),
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
# 查询实验列表
experiments = []
async for experiment in db.experiments.find({"project_id": ObjectId(project_id)}):
experiments.append(ExperimentModel(**experiment))
return experiments
# 数据模型
class ExperimentSession(BaseModel):
"""实会话模型,记录每次实验的开始和结束时间"""
id: Optional[PyObjectId] = Field(alias="_id", default=None)
experiment_id: PyObjectId
user_id: PyObjectId
start_time: datetime
end_time: Optional[datetime] = None
duration: Optional[float] = None # 持续时间(秒)
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
class ExperimentData(BaseModel):
"""实验数据模型"""
id: Optional[PyObjectId] = Field(alias="_id", default=None)
experiment_id: PyObjectId
session_ids: List[PyObjectId] # 修改为session_ids数组
user_id: PyObjectId
device_id: str
sensor_name: str
last_update: datetime # 添加最后更新时间
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# 开始实验
@router.post("/lab/experiments/{experiment_id}/start")
async def start_experiment_session(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""开始新的实验会话"""
db = await get_database()
redis = await get_redis()
try:
# 首先获取该实验关联的所有设备
experiment_devices = []
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
experiment_devices.append({
"user_device_id": str(device["user_device_id"]),
"device_name": device["device_name"],
"serial_number": device["serial_number"],
"sensors": device.get("sensors", [])
})
if not experiment_devices:
raise HTTPException(status_code=400, detail="No devices added to the experiment")
# 更新实验状态到Redis (新增)
await redis.select(199)
for device in experiment_devices:
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
await redis.set(status_key, "active")
# 创建实验会话记录
session = {
"experiment_id": ObjectId(experiment_id),
"user_id": current_user.id,
"start_time": datetime.now(timezone.utc),
"devices": experiment_devices
}
result = await db.experiment_sessions.insert_one(session)
session_id = str(result.inserted_id)
return {
"message": "Experiment session started",
"session_id": session_id,
"devices": experiment_devices
}
except Exception as e:
print(f"Error starting experiment session: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to start experiment session: {str(e)}")
finally:
if redis:
await redis.aclose()
# 停止实验
@router.post("/lab/experiments/{experiment_id}/stop")
async def stop_experiment_session(
experiment_id: str,
session_data: dict,
current_user: UserModel = Depends(get_current_user)
):
"""停止实验会话"""
db = await get_database()
redis = await get_redis()
try:
session_id = session_data.get("session_id")
if not session_id:
raise HTTPException(status_code=422, detail="Missing session_id parameter")
# 验证session是否存在且未结束
session = await db.experiment_sessions.find_one({
"_id": ObjectId(session_id),
"experiment_id": ObjectId(experiment_id),
"end_time": None
})
if not session:
raise HTTPException(status_code=404, detail="Active experiment session not found")
# 更新实验状态到Redis (新增)
await redis.select(199) # 使用同一个db
for device in session.get("devices", []):
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
await redis.set(status_key, "inactive")
# 确保start_time是带时区的
start_time = session["start_time"]
if start_time.tzinfo is None:
start_time = start_time.replace(tzinfo=timezone.utc)
end_time = datetime.now(timezone.utc) # 确保end_time带有UTC时区
duration = (end_time - start_time).total_seconds()
# 更新会话状态
result = await db.experiment_sessions.update_one(
{"_id": ObjectId(session_id)},
{
"$set": {
"end_time": end_time,
"duration": duration
}
}
)
if result.modified_count == 0:
raise HTTPException(status_code=400, detail="Failed to update session status")
return {
"message": "Experiment session stopped",
"session_id": session_id,
"duration": duration
}
except Exception as e:
print(f"Error stopping experiment session: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to stop experiment session: {str(e)}")
finally:
if redis:
await redis.aclose()
# 获取实验会话历史的路由
@router.get("/lab/experiments/{experiment_id}/sessions")
async def get_experiment_sessions(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验的所有会话记录"""
db = await get_database()
try:
sessions = []
async for session in db.experiment_sessions.find({
"experiment_id": ObjectId(experiment_id)
}).sort("start_time", -1):
sessions.append({
"session_id": str(session["_id"]),
"start_time": session["start_time"],
"end_time": session.get("end_time"),
"duration": session.get("duration")
})
return sessions
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to get experiment session records: {str(e)}")
# 导出实验数据
@router.get("/lab/experiments/{experiment_id}/export")
async def export_experiment_data(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""导出实验的所有数据为ZIP文件(包含多个CSV)"""
db = await get_database()
redis = await get_redis()
try:
zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# 获取实验的所有会话
sessions = []
async for session in db.experiment_sessions.find({
"experiment_id": ObjectId(experiment_id)
}):
end_time = session.get("end_time") or datetime.now(timezone.utc)
duration = (end_time - session["start_time"]).total_seconds()
sessions.append({
"id": str(session["_id"]),
"start_time": session["start_time"],
"end_time": end_time,
"devices": session.get("devices", []),
"duration": duration
})
if not sessions:
raise HTTPException(status_code=404, detail="No experiment session data found")
# 写入会话信息到sessions.csv
sessions_output = StringIO()
sessions_writer = csv.writer(sessions_output)
sessions_writer.writerow(['Session ID', 'Start Time', 'End Time', 'Duration (seconds)'])
for session in sessions:
sessions_writer.writerow([
session["id"],
session["start_time"].isoformat(),
session["end_time"].isoformat() if isinstance(session["end_time"], datetime) else "Active",
session["duration"]
])
zip_file.writestr('sessions.csv', sessions_output.getvalue())
# 创建README.txt
readme_content = """Experiment Data Export Instructions:
1. sessions.csv: Contains basic information for all experiment sessions
2. Data for each device is stored in folders named by device serial number
3. Each device folder contains multiple sensor_data_XXX.csv files, each file containing sensor data for a specific time period
4. All timestamps are in UTC timezone"""
zip_file.writestr('README.txt', readme_content)
# 选择db200获取原始数据
await redis.select(200)
# 按设备处理数据
for session in sessions:
for device in session["devices"]:
serial_number = device["serial_number"]
stream_key = f"experiment:{experiment_id}:{serial_number}"
try:
start_ms = int(session["start_time"].timestamp() * 1000)
end_ms = int(session["end_time"].timestamp() * 1000)
stream_data = await redis.xrange(
stream_key,
min=str(start_ms),
max=str(end_ms)
)
# 将数据分批处理,每个文件最多包含100000行数据
MAX_ROWS_PER_FILE = 100000
current_file_rows = 0
file_counter = 1
current_output = StringIO()
current_writer = csv.writer(current_output)
current_writer.writerow(['Session ID', 'Device Time', 'Sensor', 'Value'])
for entry_id, data in stream_data:
stream_timestamp = int(entry_id.split('-')[0])
device_time = datetime.fromtimestamp(
stream_timestamp / 1000,
tz=timezone.utc
)
data_str = data[b'data'] if isinstance(data.get('data'), bytes) else data['data']
if isinstance(data_str, bytes):
data_str = data_str.decode('utf-8')
point_data = json.loads(data_str)
channel_data = point_data["channel_data"]
for sensor_name, values in channel_data.items():
if isinstance(values, list):
for value in values:
current_writer.writerow([
session["id"],
device_time.isoformat(),
sensor_name,
value
])
current_file_rows += 1
else:
current_writer.writerow([
session["id"],
device_time.isoformat(),
sensor_name,
values
])
current_file_rows += 1
# 如果当前文件达到行数限制,保存并创建新文件
if current_file_rows >= MAX_ROWS_PER_FILE:
file_name = f"{serial_number}/session_{session['id']}/sensor_data_{file_counter:03d}.csv"
zip_file.writestr(file_name, current_output.getvalue())
current_output = StringIO()
current_writer = csv.writer(current_output)
current_writer.writerow(['Session ID', 'Device Time', 'Sensor', 'Value'])
current_file_rows = 0
file_counter += 1
# 保存最后一个文件
if current_file_rows > 0:
file_name = f"{serial_number}/session_{session['id']}/sensor_data_{file_counter:03d}.csv"
zip_file.writestr(file_name, current_output.getvalue())
except Exception as e:
print(f"Error processing stream data: {str(e)}")
continue
# 获取实验信息用于文件命名
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
zip_filename = f"experiment_{experiment['experiment_name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"
# 准备ZIP文<50><E69687>下载
zip_buffer.seek(0)
return StreamingResponse(
iter([zip_buffer.getvalue()]),
media_type="application/zip",
headers={
'Content-Disposition': f'attachment; filename="{zip_filename}"'
}
)
except Exception as e:
print(f"Error exporting data: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to export data: {str(e)}")
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 完成实验的路由
@router.post("/lab/experiments/{experiment_id}/complete")
async def complete_experiment(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""完成实验,将实验状态标记为已完成"""
db = await get_database()
try:
# 验证实验是否存在
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 检查是否有正在进行的会话
active_session = await db.experiment_sessions.find_one({
"experiment_id": ObjectId(experiment_id),
"end_time": None
})
if active_session:
raise HTTPException(status_code=400, detail="Please stop all ongoing experiment sessions first")
# 更新实验状态为已完成
result = await db.experiments.update_one(
{"_id": ObjectId(experiment_id)},
{"$set": {"status": ExperimentStatus.COMPLETED}}
)
if result.modified_count == 0:
raise HTTPException(status_code=400, detail="Failed to update experiment status")
return {"message": "Experiment completed"}
except Exception as e:
print(f"Error completing experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to complete experiment: {str(e)}")
# 进入实验的路由,添加状态检查
@router.get("/lab/experiments/{experiment_id}")
async def get_experiment(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验详情,包括状态"""
db = await get_database()
try:
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 创建一个新的字典来存储处理后的数据
experiment_dict = {
"_id": str(experiment["_id"]),
"project_id": str(experiment["project_id"]),
"experiment_name": experiment["experiment_name"],
"create_time": experiment["create_time"],
"description": experiment["description"],
"status": experiment.get("status", ExperimentStatus.ACTIVE)
}
# 获取实验会话历史
sessions = []
async for session in db.experiment_sessions.find({"experiment_id": ObjectId(experiment_id)}):
session_dict = {
"id": str(session["_id"]),
"experiment_id": str(session["experiment_id"]),
"start_time": session["start_time"],
"end_time": session.get("end_time"),
"duration": session.get("duration")
}
sessions.append(session_dict)
experiment_dict["sessions"] = sessions
return experiment_dict
except Exception as e:
print(f"Error getting experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get experiment: {str(e)}")
# 删除实验的路由
@router.delete("/lab/experiments/{experiment_id}")
async def delete_experiment(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""删除实验及其相关数据"""
db = await get_database()
redis = await get_redis()
try:
# 验证实验是否存在且属于当前用户的项目
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
project = await db.projects.find_one({
"_id": experiment["project_id"],
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=403, detail="Unauthorized to delete this experiment")
# 删除实验相关数据
await db.experiment_sessions.delete_many({"experiment_id": ObjectId(experiment_id)})
await db.experiment_devices.delete_many({"experiment_id": ObjectId(experiment_id)})
await db.experiments.delete_one({"_id": ObjectId(experiment_id)})
# 删除Redis中的实验报告
await redis.select(201)
await redis.delete(f"experiment_report:{experiment_id}")
return {"message": "Experiment successfully deleted"}
except Exception as e:
print(f"Error deleting experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to delete experiment: {str(e)}")
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
# 添加公式模型
class FormulaModel(BaseModel):
data_name: str
data_unit: str
formula: str
# 添加公式请求模型
class FormulaCreate(BaseModel):
sensor_name: str
data_name: str
data_unit: str
formula: str
@router.post("/lab/experiments/{experiment_id}/devices/{device_id}/formulas")
async def add_sensor_formula(
experiment_id: str,
device_id: str,
formula: FormulaCreate,
current_user: UserModel = Depends(get_current_user)
):
"""为实验设备的传感器添加计算公式"""
db = await get_database()
try:
# 验证实验设备是否存在,使用 user_device_id 查询
device = await db.experiment_devices.find_one({
"user_device_id": ObjectId(device_id),
"experiment_id": ObjectId(experiment_id)
})
if not device:
raise HTTPException(status_code=404, detail="Experiment device not found")
# 查找对应的传感器
sensor_found = False
sensors = device.get("sensors", [])
for sensor in sensors:
if sensor["sensor_name"] == formula.sensor_name:
# 初始化或获取现有公式列表
if "formulas" not in sensor:
sensor["formulas"] = []
# 添加新公式
sensor["formulas"].append({
"data_name": formula.data_name,
"data_unit": formula.data_unit,
"formula": formula.formula
})
sensor_found = True
break
if not sensor_found:
raise HTTPException(status_code=404, detail="Sensor not found")
# 更新设备文档,使用 user_device_id 更新
result = await db.experiment_devices.update_one(
{
"user_device_id": ObjectId(device_id),
"experiment_id": ObjectId(experiment_id)
},
{"$set": {"sensors": sensors}}
)
if result.modified_count == 0:
raise HTTPException(status_code=400, detail="Failed to add formula")
return {"message": "Formula added successfully"}
except Exception as e:
print(f"Error adding formula: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to add formula: {str(e)}")
+238
View File
@@ -0,0 +1,238 @@
from fastapi import APIRouter, HTTPException, Depends
from typing import List, Optional
from datetime import datetime, timezone
from bson import ObjectId
from pydantic import BaseModel, Field
from ..cores.db import PyObjectId, get_database
from .login import get_current_user, UserModel
router = APIRouter()
class ExperimentDeviceModel(BaseModel):
"""实验-设备关联模型"""
id: Optional[PyObjectId] = Field(alias="_id", default=None)
experiment_id: PyObjectId
user_device_id: PyObjectId
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
@router.post("/lab/experiments/{experiment_id}/devices/{user_device_id}")
async def add_device_to_experiment(
experiment_id: str,
user_device_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""将设备添加到特定实验中"""
db = await get_database()
try:
# 验证实验是否存在且属于当前用户
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 验证项目所有权
project = await db.projects.find_one({
"_id": experiment["project_id"],
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=403, detail="Unauthorized access to this experiment")
# 验证用户设备是否存在且状态为active
user_device = await db.user_devices.find_one({
"_id": ObjectId(user_device_id),
"user_id": current_user.id,
"status": "active" # 添加状态检查
})
if not user_device:
raise HTTPException(status_code=404, detail="Active device not found or unauthorized access")
# 检查是否已经添加过这个设备
existing = await db.experiment_devices.find_one({
"experiment_id": ObjectId(experiment_id),
"user_device_id": ObjectId(user_device_id)
})
if existing:
raise HTTPException(status_code=400, detail="This device has already been added to the experiment")
# 创建关联记录,直接复制用户设备的所有信息
experiment_device_dict = {
"experiment_id": ObjectId(experiment_id),
"user_device_id": ObjectId(user_device_id),
**{k:v for k,v in user_device.items() if k not in ['_id', 'user_id']}
}
result = await db.experiment_devices.insert_one(experiment_device_dict)
return {
"message": "Device added successfully",
"experiment_device_id": str(result.inserted_id)
}
except Exception as e:
print(f"Error adding device to experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to add device: {str(e)}")
@router.get("/lab/experiments/{experiment_id}/devices")
async def get_experiment_devices(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验中的所有设备信息"""
db = await get_database()
try:
# 验证实验是否存在且属于当前用户
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 验证项目所有权
project = await db.projects.find_one({
"_id": experiment["project_id"],
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=403, detail="Unauthorized access to this experiment")
# 获取实验关联的所有设备
devices = []
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
# 确保返回传感器的公式信息
device_info = {
"_id": str(device["_id"]),
"experiment_id": str(device["experiment_id"]),
"user_device_id": str(device["user_device_id"]),
"device_id": str(device["device_id"]),
"serial_number": device["serial_number"],
"device_name": device["device_name"],
"device_type": device["device_type"],
"device_number": device["device_number"],
"sensors": []
}
# 保留完整的传感器信息,包括 index
for sensor in device.get("sensors", []):
sensor_info = {
"index": sensor["index"], # 保留 index
"sensor_name": sensor["sensor_name"],
"sensor_type": sensor["sensor_type"],
"unit": sensor["unit"]
}
# 如果有公式信息,也添加进去
if "formulas" in sensor:
sensor_info["formulas"] = sensor["formulas"]
device_info["sensors"].append(sensor_info)
devices.append(device_info)
return devices
except Exception as e:
print(f"Error getting experiment devices: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
@router.get("/lab/experiments/{experiment_id}/devices/public")
async def get_experiment_devices_public(
experiment_id: str
):
"""获取实验中的所有设备信息(公开接口)"""
db = await get_database()
try:
# 验证实验是否存在
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 获取实验关联的所有设备
devices = []
async for device in db.experiment_devices.find({"experiment_id": ObjectId(experiment_id)}):
# 确保返回传感器的公式信息
device_info = {
"_id": str(device["_id"]),
"experiment_id": str(device["experiment_id"]),
"user_device_id": str(device["user_device_id"]),
"device_id": str(device["device_id"]),
"serial_number": device["serial_number"],
"device_name": device["device_name"],
"device_type": device["device_type"],
"device_number": device["device_number"],
"sensors": []
}
# 保留完整的传感器信息,包括 index
for sensor in device.get("sensors", []):
sensor_info = {
"index": sensor["index"], # 保留 index
"sensor_name": sensor["sensor_name"],
"sensor_type": sensor["sensor_type"],
"unit": sensor["unit"]
}
# 如果有公式信息,也添加进去
if "formulas" in sensor:
sensor_info["formulas"] = sensor["formulas"]
device_info["sensors"].append(sensor_info)
devices.append(device_info)
return devices
except Exception as e:
print(f"Error getting experiment devices: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
@router.delete("/lab/experiments/{experiment_id}/devices/{user_device_id}")
async def remove_device_from_experiment(
experiment_id: str,
user_device_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""
从实验中移除特定设备
参数:
experiment_id: 实验ID
user_device_id: 用户设备ID
current_user: 当前登录用户
"""
db = await get_database()
try:
# 验证实验是否存在且属于当前用户
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="Experiment not found")
# 验证项目所有权
project = await db.projects.find_one({
"_id": experiment["project_id"],
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=403, detail="Unauthorized to delete this experiment")
# 删除实验-设备关联
result = await db.experiment_devices.delete_one({
"experiment_id": ObjectId(experiment_id),
"user_device_id": ObjectId(user_device_id)
})
if result.deleted_count == 0:
raise HTTPException(status_code=404, detail="Device association not found")
return {
"message": "Device removed from experiment",
"experiment_id": experiment_id,
"user_device_id": user_device_id
}
except Exception as e:
print(f"Error removing device from experiment: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to remove device: {str(e)}")
+515
View File
@@ -0,0 +1,515 @@
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any
import json
from datetime import datetime, timezone
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
from redis import asyncio as aioredis
from openai import OpenAI
from pydantic import BaseModel
from ..cores.config import (
MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG,
REDIS_DB_MAP, TaskStatus, THREAD_POOL_CONFIG, format_response,
exp_analysis_thread_pool
)
from ..cores.db import get_database, get_redis
from .login import get_current_user, UserModel
router = APIRouter(
prefix="/lab",
tags=["experiment_report"]
)
# DeepSeek API Configuration
client = OpenAI(
base_url=DEEPSEEK_API_CONFIG["base_url"],
api_key=DEEPSEEK_API_CONFIG["api_key"]
)
async def analyze_experiment_data(experiment_info):
system_prompt = """
You are an AI assistant tasked with analyzing experimental data.
Generate a comprehensive experiment analysis report in JSON format.
The JSON structure must strictly follow the provided template.
"""
user_prompt = f"""Analyze the experiment data based on the following information:
Experiment data: {json.dumps(experiment_info['sessions'], ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"Experiment Analysis Report": {{
"1. Basic Information": {{
"Total Sessions": "{experiment_info['total_sessions']}",
"Total Duration": "{experiment_info['total_duration']} seconds",
"Data Points": "{experiment_info['total_points']}",
"Devices number": "{experiment_info['device_stats']['total_devices']}",
"Sensors number": "{experiment_info['device_stats']['total_sensors']}"
}},
"2. Session Data Analysis": {{
"[Session ID]": {{
"Duration": "[Duration] seconds",
"Data Points": "[data_points]"
}},
// Repeat for each session dynamically
}},
"3. Key Findings": [
"[Finding 1]",
"[Finding 2]",
"[Finding 3]"
],
"4. Recommendations": [
"[Recommendation 1]",
"[Recommendation 2]",
"[Recommendation 3]"
]
}}
}}
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
try:
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages,
response_format={'type': 'json_object'}
)
return json.loads(response.choices[0].message.content)
except Exception as e:
print(f"Error calling DeepSeek API: {e}")
return None
def run_experiment_in_thread(experiment_id: str):
"""在独立线程中运行分析任务"""
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在新的事件循环中运行异步任务
loop.run_until_complete(process_experiment_analysis(experiment_id))
except Exception as e:
print(f"Error in analysis thread: {str(e)}")
finally:
try:
# 清理所有待处理的任务
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# 运行直到所有任务完成
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
except Exception as e:
print(f"Error cleaning up tasks: {str(e)}")
finally:
loop.close()
async def process_experiment_analysis(experiment_id: str):
"""后台处理实验分析任务"""
print(f"\n=== 开始实验分析 ===")
print(f"实验ID: {experiment_id}")
# 创建新的数据库和Redis连接
mongo_client = AsyncIOMotorClient(MONGODB_URL)
db = mongo_client["lab"]
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
try:
await redis.select(REDIS_DB_MAP["experiment_analysis"])
status_key = f"experiment_analysis_status:{experiment_id}"
# 更新状态为进行中
status_data = {
"status": "processing",
"start_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
# 获取实验基本信息
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise ValueError("实验不存在")
# 收集会话数据
sessions = []
total_points = 0
total_duration = 0
devices_set = set()
sensors_count = 0
async for session in db.experiment_sessions.find({
"experiment_id": ObjectId(experiment_id)
}):
# 计算会话持续时间
end_time = session.get("end_time") or datetime.now(timezone.utc)
duration = (end_time - session["start_time"]).total_seconds()
total_duration += duration
# 转换时间戳为毫秒
start_ms = int(session["start_time"].timestamp() * 1000)
end_ms = int(end_time.timestamp() * 1000)
# 统计设备和传感器数量
session_devices = session.get("devices", [])
for device in session_devices:
devices_set.add(device["serial_number"])
sensors_count += len(device.get("sensors", []))
# 获取会话数据点数
session_points = 0
for device in session_devices:
stream_key = f"experiment:{experiment_id}:{device['serial_number']}"
await redis.select(200)
data_points = await redis.xrange(
stream_key,
min=str(start_ms),
max=str(end_ms)
)
session_points += len(data_points)
total_points += session_points
sessions.append({
"session_id": str(session["_id"]),
"duration": duration,
"data_points": session_points
})
if not sessions:
raise ValueError("没有找到实验会话数据")
experiment_info = {
"experiment_name": experiment["experiment_name"],
"total_sessions": len(sessions),
"total_duration": total_duration,
"total_points": total_points,
"device_stats": {
"total_devices": len(devices_set),
"total_sensors": sensors_count
},
"sessions": sessions
}
# 执行分析
analysis_result = await analyze_experiment_data(experiment_info)
if analysis_result:
# 保存分析结果
await redis.select(REDIS_DB_MAP["experiment_analysis"])
report_key = f"experiment_report:{experiment_id}"
await redis.set(report_key, json.dumps(analysis_result))
# 更新状态为完成
status_data = {
"status": "completed",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
print("分析报告已保存")
else:
print("分析失败")
status_data = {
"status": "failed",
"error": "Failed to generate analysis result",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
except Exception as e:
print(f"Error in experiment analysis: {str(e)}")
try:
await redis.select(REDIS_DB_MAP["experiment_analysis"])
status_key = f"experiment_analysis_status:{experiment_id}"
status_data = {
"status": "failed",
"error": str(e),
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
except Exception as redis_error:
print(f"Error updating Redis status: {redis_error}")
finally:
try:
await redis.aclose()
mongo_client.close()
except Exception as e:
print(f"Error closing connections: {e}")
@router.get("/experiments/{experiment_id}/analyze")
async def analyze_data(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""启动实验数据分析"""
db = await get_database()
redis = await get_redis()
try:
await redis.select(REDIS_DB_MAP["experiment_analysis"])
status_key = f"experiment_analysis_status:{experiment_id}"
current_status = await redis.get(status_key)
if current_status:
status_data = json.loads(current_status)
if status_data.get("status") == "processing":
return format_response({
"experiment_id": experiment_id,
"start_time": status_data.get("start_time")
}, "实验分析任务正在进行中")
# 记录分析开始状态
status_data = {
"status": "processing",
"start_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
# 使用线程池提交任务
exp_analysis_thread_pool.submit(run_experiment_in_thread, experiment_id)
return format_response({
"experiment_id": experiment_id,
"start_time": status_data["start_time"]
}, "实验分析任务已启动")
except Exception as e:
print(f"Error starting experiment analysis: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/experiments/{experiment_id}/analysis_status")
async def get_experiment_analysis_status(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验分析任务的状态"""
redis = await get_redis()
try:
await redis.select(REDIS_DB_MAP["experiment_analysis"])
status_key = f"experiment_analysis_status:{experiment_id}"
status_data = await redis.get(status_key)
if not status_data:
return format_response({
"experiment_id": experiment_id,
"status": "not_started"
})
return format_response(json.loads(status_data))
except Exception as e:
print(f"Error getting analysis status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/experiments/{experiment_id}/report")
async def get_saved_report(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""从 Redis 读取已保存的实验报告"""
redis = await get_redis()
try:
await redis.select(REDIS_DB_MAP["experiment_analysis"])
report_key = f"experiment_report:{experiment_id}"
existing_report = await redis.get(report_key)
if not existing_report:
raise HTTPException(status_code=404, detail="No saved experiment report found")
return format_response(json.loads(existing_report))
except Exception as e:
print(f"Error getting experiment report: {e}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
class QuestionModel(BaseModel):
"""问题模型"""
question: str
@router.post("/experiments/{experiment_id}/qa")
async def ask_experiment_question(
experiment_id: str,
question: QuestionModel,
current_user: UserModel = Depends(get_current_user)
):
"""向实验报告提问(异步)"""
db = await get_database()
redis = await get_redis()
try:
# 验证实验是否存在
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
if not experiment:
raise HTTPException(status_code=404, detail="实验不存在")
# 生成任务ID
task_id = str(ObjectId())
# 创建后台任务
asyncio.create_task(process_experiment_question(
task_id=task_id,
experiment_id=experiment_id,
question=question.question,
experiment=experiment
))
return {
"task_id": task_id,
"status": TaskStatus.PENDING,
"message": "问题已提交,正在处理中"
}
except Exception as e:
print(f"Error in ask_experiment_question: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def process_experiment_question(task_id: str, experiment_id: str, question: str, experiment: dict):
"""处理实验问答的后台任务"""
redis = await get_redis()
try:
# 更新任务状态为处理中
await redis.select(208) # 使用db208存储任务状态
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.PROCESSING,
"experiment_id": experiment_id,
"question": question
}
)
# 从Redis获取分析报告
await redis.select(201)
report_key = f"experiment_report:{experiment_id}"
report_data = await redis.get(report_key)
if not report_data:
raise Exception("实验分析报告不存在,请先进行分析")
# 构建系统提示和用户提示
system_prompt = """
你是一个专业的实验助手,负责回答关于实验报告的问题。
你应该基于实验报告提供准确、专业的回答。
回答应当简洁明了,并尽可能引用报告中的具体内容。
"""
# 构建上下文
context = {
"实验名称": experiment.get("experiment_name", "未知实验"),
"分析报告": json.loads(report_data)
}
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"基于以下实验报告回答问题:\n\n实验信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
]
# 调用DeepSeek API获取回答
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages
)
answer = response.choices[0].message.content
# 保存对话历史到Redis db207
await redis.select(207)
chat_history_key = f"experiment_chat_history:{experiment_id}"
# 获取现有历史记录
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
# 添加新的对话
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 保存更新后的历史记录
await redis.set(chat_history_key, json.dumps(history))
# 更新任务状态为完成
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.COMPLETED,
"answer": answer,
"experiment_name": experiment.get("experiment_name")
}
)
except Exception as e:
print(f"Error processing question: {e}")
# 更新任务状态为失败
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.FAILED,
"error": str(e)
}
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/experiments/{experiment_id}/qa/history")
async def get_experiment_qa_history(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验问答历史记录"""
redis = await get_redis()
try:
# 从Redis db207获取对话历史
await redis.select(207)
chat_history_key = f"experiment_chat_history:{experiment_id}"
history_data = await redis.get(chat_history_key)
if not history_data:
return []
return json.loads(history_data)
except Exception as e:
print(f"Error getting QA history: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
+112
View File
@@ -0,0 +1,112 @@
from fastapi import APIRouter, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from datetime import datetime, timedelta, timezone
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError
from pydantic import BaseModel, EmailStr, Field
from typing import Optional
from bson import ObjectId
from ..cores.db import PyObjectId, get_database
# JWT Configuration
SECRET_KEY = "Obscura@2024"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 30
# OAuth2 scheme
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Router
router = APIRouter(
prefix="/lab",
tags=["authentication"]
)
# Pydantic models
class UserModel(BaseModel):
"""
用户模型
包含用户基本信息:用户名、密码、邮箱、姓名、所属机构
"""
id: Optional[PyObjectId] = Field(alias="_id", default=None)
username: str
password: str
email: EmailStr
name: str
institution: str
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# Security functions
def create_access_token(data: dict):
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except (ExpiredSignatureError, InvalidSignatureError) as e:
raise HTTPException(
status_code=401,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
db = await get_database()
user = await db.users.find_one({"username": username})
if user is None:
raise credentials_exception
return UserModel(**user)
# Auth routes
@router.post("/register")
async def register(user_data: UserModel):
"""
用户注册接口
"""
db = await get_database()
# 检查用户名是否已存在
if await db.users.find_one({"username": user_data.username}):
raise HTTPException(status_code=400, detail="Username has already been registered")
# 创建用户文档
user_dict = user_data.model_dump(exclude={"id"})
try:
result = await db.users.insert_one(user_dict)
return {
"message": "Registration successful",
"id": str(result.inserted_id)
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Database Error: {str(e)}")
@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
"""
用户登录接口
"""
db = await get_database()
user = await db.users.find_one({"username": form_data.username})
# 直接比较明文密码
if not user or form_data.password != user["password"]:
raise HTTPException(status_code=400, detail="Incorrect username or password")
access_token = create_access_token(data={"sub": user["username"]})
return {"access_token": access_token, "token_type": "bearer"}
+141
View File
@@ -0,0 +1,141 @@
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from datetime import datetime, timezone
import json
from ..cores.db import get_redis
from .login import get_current_user, UserModel
router = APIRouter(
prefix="/lab",
tags=["memo"]
)
class MemoModel(BaseModel):
"""项目备忘录模型"""
content: str
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@router.post("/projects/{project_id}/memo")
async def save_project_memo(
project_id: str,
memo: MemoModel,
current_user: UserModel = Depends(get_current_user)
):
"""保存项目备忘录"""
redis = await get_redis()
try:
# 选择 db205
await redis.select(205)
memo_key = f"memo:{project_id}"
# 保存备忘录内容和创建时间
memo_data = {
"content": memo.content,
"create_time": memo.create_time.isoformat()
}
await redis.set(memo_key, json.dumps(memo_data))
return {"message": "备忘录保存成功"}
except Exception as e:
print(f"Error saving memo: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/projects/{project_id}/memo")
async def get_project_memo(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取项目备忘录"""
redis = await get_redis()
try:
# 选择 db205
await redis.select(205)
memo_key = f"memo:{project_id}"
# 获取备忘录
memo_data = await redis.get(memo_key)
if not memo_data:
return {"content": "", "create_time": None}
return json.loads(memo_data)
except Exception as e:
print(f"Error getting memo: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.post("/experiments/{experiment_id}/memo")
async def save_experiment_memo(
experiment_id: str,
memo: MemoModel,
current_user: UserModel = Depends(get_current_user)
):
"""保存实验备忘录"""
redis = await get_redis()
try:
# 选择 db206
await redis.select(206)
memo_key = f"memo:{experiment_id}"
# 保存备忘录内容和创建时间
memo_data = {
"content": memo.content,
"create_time": memo.create_time.isoformat()
}
await redis.set(memo_key, json.dumps(memo_data))
return {"message": "Experiment memo saved successfully"}
except Exception as e:
print(f"Error saving experiment memo: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/experiments/{experiment_id}/memo")
async def get_experiment_memo(
experiment_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取实验备忘录"""
redis = await get_redis()
try:
# 选择 db206
await redis.select(206)
memo_key = f"memo:{experiment_id}"
# 获取备忘录
memo_data = await redis.get(memo_key)
if not memo_data:
return {"content": "", "create_time": None}
return json.loads(memo_data)
except Exception as e:
print(f"Error getting experiment memo: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
+410
View File
@@ -0,0 +1,410 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from typing import List
from bson import ObjectId
import os
import asyncio
from datetime import datetime, timezone
import json
from ..cores.db import get_database, get_redis
from ..cores.config import UPLOAD_PATH, TaskStatus
from ..routers.login import get_current_user, UserModel
from ..models.paper import (
ReferenceModel, QuestionModel,
process_batch_analysis, process_reference_question,
run_analysis_in_thread
)
router = APIRouter(
prefix="/lab",
tags=["paper"]
)
@router.delete("/projects/{project_id}/references/{reference_id}")
async def delete_reference(
project_id: str,
reference_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""删除项目文献"""
db = await get_database()
try:
# 验证项目所有权
project = await db.projects.find_one({
"_id": ObjectId(project_id),
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=403, detail="Unauthorized to access this project")
# 获取文献信息
reference = await db.references.find_one({"_id": ObjectId(reference_id)})
if not reference:
raise HTTPException(status_code=404, detail="Reference not found")
# 删除文件
if os.path.exists(reference["reference_link"]):
os.remove(reference["reference_link"])
# 删除数据库记录
await db.references.delete_one({"_id": ObjectId(reference_id)})
return {"message": "Reference successfully deleted"}
except Exception as e:
print(f"Error deleting reference: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to delete reference: {str(e)}")
@router.get("/projects/{project_id}/references")
async def get_project_references(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取项目的文献列表"""
db = await get_database()
try:
# 查找项目下的所有文献
references_cursor = db.references.find({
"project_id": ObjectId(project_id)
})
references = []
async for ref in references_cursor:
references.append({
"_id": str(ref["_id"]),
"project_id": str(ref["project_id"]),
"reference_link": ref["reference_link"],
"reference_title": ref["reference_title"],
"upload_time": ref["upload_time"]
})
return references
except Exception as e:
print(f"Error getting project references: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/references/{reference_id}/report")
async def get_reference_report(
reference_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""从 Redis db203 读取已保存的文献报告"""
redis = await get_redis()
try:
await redis.select(203)
report_key = f"reference_report:{reference_id}"
existing_report = await redis.get(report_key)
if not existing_report:
raise HTTPException(status_code=404, detail="No saved reference report found")
return json.loads(existing_report)
except Exception as e:
print(f"Error getting reference report: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/references/{project_id}/analyze_report")
async def analyze_reference_summary_report(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""分析文献数据"""
db = await get_database()
redis = await get_redis()
try:
await redis.select(204)
status_key = f"reference_analysis_status:{project_id}"
current_status = await redis.get(status_key)
if current_status:
status_data = json.loads(current_status)
if status_data.get("status") == "processing":
return {
"message": "文献分析任务正在进行中",
"status": "processing",
"project_id": project_id,
"start_time": status_data.get("start_time")
}
reference_cursor = db.references.find({
"project_id": ObjectId(project_id)
})
reference_ids = []
async for reference in reference_cursor:
reference_ids.append(str(reference["_id"]))
if not reference_ids:
raise HTTPException(status_code=404, detail="No references found for this project")
status_data = {
"status": "processing",
"start_time": datetime.now(timezone.utc).isoformat(),
"total_references": len(reference_ids),
"completed_references": 0
}
await redis.set(status_key, json.dumps(status_data))
asyncio.create_task(run_analysis_in_thread(project_id, reference_ids))
return {
"message": "文献分析任务已启动",
"status": "processing",
"project_id": project_id,
"start_time": status_data["start_time"],
"total_references": len(reference_ids)
}
except Exception as e:
print(f"Error starting reference analysis: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/references/{project_id}/analysis_status")
async def get_reference_analysis_status(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取文献分析任务的状态"""
redis = await get_redis()
try:
await redis.select(204)
status_key = f"reference_analysis_status:{project_id}"
status_data = await redis.get(status_key)
if not status_data:
return {
"status": "not_started",
"project_id": project_id
}
return json.loads(status_data)
except Exception as e:
print(f"Error getting analysis status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/references/{project_id}/summary_report")
async def get_reference_summary_report(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""从 Redis db204 读取已保存的文献报告"""
redis = await get_redis()
try:
await redis.select(204)
report_key = f"reference_summary_report:{project_id}"
existing_report = await redis.get(report_key)
if not existing_report:
raise HTTPException(status_code=404, detail="No saved reference report found")
return json.loads(existing_report)
except Exception as e:
print(f"Error getting reference report: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.post("/references/{reference_id}/qa")
async def ask_reference_question(
reference_id: str,
question: QuestionModel,
current_user: UserModel = Depends(get_current_user)
):
"""向文献提问(异步)"""
db = await get_database()
redis = await get_redis()
try:
reference = await db.references.find_one({"_id": ObjectId(reference_id)})
if not reference:
raise HTTPException(status_code=404, detail="文献不存在")
task_id = str(ObjectId())
asyncio.create_task(process_reference_question(
task_id=task_id,
reference_id=reference_id,
question=question.question,
reference=reference
))
return {
"task_id": task_id,
"status": TaskStatus.PENDING,
"message": "问题已提交,正在处理中"
}
except Exception as e:
print(f"Error in ask_reference_question: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/task/{task_id}")
async def get_task_status(task_id: str):
"""获取任务状态和结果"""
redis = await get_redis()
try:
await redis.select(208)
task_data = await redis.hgetall(f"task:{task_id}")
if not task_data:
raise HTTPException(status_code=404, detail="任务不存在")
response = {
"task_id": task_id,
"status": task_data.get("status", TaskStatus.PENDING)
}
if task_data.get("status") == TaskStatus.COMPLETED:
response.update({
"answer": task_data.get("answer"),
"reference_title": task_data.get("reference_title")
})
elif task_data.get("status") == TaskStatus.FAILED:
response.update({
"error": task_data.get("error")
})
return response
except Exception as e:
print(f"Error getting task status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/references/{reference_id}/qa/history")
async def get_reference_qa_history(
reference_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取文献问答历史记录"""
redis = await get_redis()
try:
await redis.select(207)
chat_history_key = f"chat_history:{reference_id}"
history_data = await redis.get(chat_history_key)
if not history_data:
return []
return json.loads(history_data)
except Exception as e:
print(f"Error getting QA history: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.post("/projects/{project_id}/references/batch")
async def batch_upload_references(
project_id: str,
files: List[UploadFile] = File(...),
current_user: UserModel = Depends(get_current_user)
):
"""批量上传项目相关文献"""
db = await get_database()
try:
project = await db.projects.find_one({
"_id": ObjectId(project_id),
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
uploaded_references = []
for file in files:
allowed_types = ["application/pdf", "application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"]
if file.content_type not in allowed_types:
continue
os.makedirs(UPLOAD_PATH, exist_ok=True)
file_extension = os.path.splitext(file.filename)[1]
safe_filename = f"{project_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}{file_extension}"
file_path = os.path.join(UPLOAD_PATH, safe_filename)
with open(file_path, "wb") as buffer:
content = await file.read()
buffer.write(content)
reference = {
"project_id": ObjectId(project_id),
"reference_link": file_path,
"reference_title": file.filename,
"upload_time": datetime.now(timezone.utc)
}
result = await db.references.insert_one(reference)
reference_info = {
"reference_id": str(result.inserted_id),
"file_path": file_path,
"reference_title": reference["reference_title"]
}
uploaded_references.append(reference_info)
redis = await get_redis()
try:
await redis.select(203)
report_key = f"reference_report:{str(result.inserted_id)}"
initial_status = {
"status": "processing",
"message": "Analysis in progress"
}
await redis.set(report_key, json.dumps(initial_status))
finally:
await redis.aclose()
if uploaded_references:
asyncio.create_task(process_batch_analysis(uploaded_references))
return {
"message": f"Successfully uploaded {len(uploaded_references)} files",
"uploaded_files": uploaded_references
}
except Exception as e:
print(f"Batch upload error: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
+133
View File
@@ -0,0 +1,133 @@
from fastapi import APIRouter, HTTPException, Depends
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel, Field
from bson import ObjectId
from ..cores.db import PyObjectId, get_database, get_redis
from .login import get_current_user, UserModel
# Router
router = APIRouter(
prefix="/lab",
tags=["projects"]
)
# Pydantic models
class ProjectModel(BaseModel):
"""
项目模型
包含项目信息:用户ID、项目名称、创建时间、描述
"""
id: Optional[PyObjectId] = Field(alias="_id", default=None)
user_id: Optional[PyObjectId] = None
project_name: str
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
description: str
class Config:
populate_by_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str}
# Project routes
@router.post("/projects")
async def create_project(
project: ProjectModel,
current_user: UserModel = Depends(get_current_user)
):
"""
创建新项目
参数:
project: 项目信息
current_user: 当前登录用户
"""
db = await get_database()
project_dict = {
"user_id": current_user.id,
"project_name": project.project_name,
"description": project.description,
"create_time": datetime.now(timezone.utc)
}
result = await db.projects.insert_one(project_dict)
return {
"message": "Project created successfully",
"id": str(result.inserted_id)
}
@router.get("/projects")
async def get_projects(current_user: UserModel = Depends(get_current_user)):
"""
获取当前用户所有项目
参数:
current_user: 当前登录用户,通过JWT token验证获取
返回:
projects: 项目列表,每个项目包含完整的项目信息
"""
db = await get_database()
# 初始化项目列表
projects = []
# 异步查询数据库,获取该用户的所有项目
async for project in db.projects.find({"user_id": current_user.id}):
projects.append(ProjectModel(**project))
return projects
# 添加删除项目的路由
@router.delete("/lab/projects/{project_id}")
async def delete_project(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""删除项目及其所有相关数据"""
db = await get_database()
redis = await get_redis()
try:
# 验证项目是否存在且属于当前用户
project = await db.projects.find_one({
"_id": ObjectId(project_id),
"user_id": current_user.id
})
if not project:
raise HTTPException(status_code=404, detail="Project not found or unauthorized access")
# 获取项目下所有实验的ID
experiment_ids = []
async for exp in db.experiments.find({"project_id": ObjectId(project_id)}):
experiment_ids.append(str(exp["_id"]))
# 删除所有相关数据
for exp_id in experiment_ids:
await db.experiment_sessions.delete_many({"experiment_id": ObjectId(exp_id)})
await db.experiment_devices.delete_many({"experiment_id": ObjectId(exp_id)})
# 删除Redis中的实验报告
await redis.select(201)
await redis.delete(f"experiment_report:{exp_id}")
# 删除所有实验
await db.experiments.delete_many({"project_id": ObjectId(project_id)})
# 删除项目
await db.projects.delete_one({"_id": ObjectId(project_id)})
# 删除Redis中的项目报告
await redis.select(202)
await redis.delete(f"project_report:{project_id}")
return {"message": "Project successfully deleted"}
except Exception as e:
print(f"Error deleting project: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to delete project: {str(e)}")
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
+558
View File
@@ -0,0 +1,558 @@
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
import json
import asyncio
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from openai import OpenAI
from redis import asyncio as aioredis
from pydantic import BaseModel
from ..cores.config import MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG, TaskStatus
from ..cores.db import get_database, get_redis
from .login import get_current_user, UserModel
router = APIRouter(
prefix="/lab",
tags=["project_report"]
)
# DeepSeek API Configuration
client = OpenAI(
base_url=DEEPSEEK_API_CONFIG["base_url"],
api_key=DEEPSEEK_API_CONFIG["api_key"]
)
async def analyze_project_data(project_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
system_prompt = """
You are an AI assistant responsible for analyzing lab reports.
You will summarize and analyze all lab reports and generate a comprehensive analysis report in JSON format.
The JSON structure must strictly follow the provided template.
"""
user_prompt = f"""The project report is summarized based on the following experimental analysis reports:
Experimental analysis reports: {json.dumps(project_data, ensure_ascii=False)}
Generate a JSON response with the following structure:
{{
"Project Analysis Report": {{
"1. Project Overview": {{
"Project Name": "{project_data['project_stats']['project_name']}",
"Total Experiments": "{project_data['project_stats']['total_experiments']}",
"Total Data Points": "{project_data['project_stats']['total_data_points']}",
}},
"2. Aggregated Statistics": {{
"Total Sessions": {project_data['project_stats']['total_sessions']},
"Total Duration": {project_data['project_stats']['total_duration']},
"Total Data Points": {project_data['project_stats']['total_data_points']},
"Average Session Duration": {project_data['project_stats']['avg_session_duration']},
"Average Data Points per Session": {project_data['project_stats']['avg_data_points_per_session']}
}},
"3. Performance Analysis": {{
"Best Performing Sessions": [
{{
"Session ID": "session_id",
"experiment_id": "experiment_id",
"experiment_name": "experiment_name",
"Duration": "duration",
"Data Points": "data_points",
"Success Factors": "success_factors"
}}
],
"Problematic Sessions": [
{{
"Session ID": "session_id",
"experiment_id": "experiment_id",
"experiment_name": "experiment_name",
"Issues": "specific_issues",
"Possible Causes": "possible_causes"
}}
]
}},
"4. Common Findings": {{
"Recurring Issues": "common_failure_modes",
"Equipment performance": "equipment_performance",
"Sensor reliability analysis": "sensor_reliability_analysis"
}},
"5. Recommendations": {{
"Equipment optimization suggestions": "equipment_optimization_suggestions",
"Experiment process improvement suggestions": "experiment_process_improvement_suggestions",
"Data collection strategy adjustment": "data_collection_strategy_adjustment"
}}
}}
}}
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
try:
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages,
response_format={'type': 'json_object'}
)
return json.loads(response.choices[0].message.content)
except Exception as e:
print(f"Error calling DeepSeek API: {e}")
return None
@router.get("/projects/{project_id}/analyze")
async def analyze_project_data_endpoint(
project_id: str,
force: bool = False,
current_user: UserModel = Depends(get_current_user)
):
"""启动项目数据分析"""
print(f"\n=== analyze_project_data_endpoint 开始 ===")
print(f"项目ID: {project_id}")
print(f"Force: {force}")
db = await get_database()
redis = await get_redis()
try:
# 检查是否已经有正在进行的分析任务
await redis.select(202)
status_key = f"project_analysis_status:{project_id}"
current_status = await redis.get(status_key)
if current_status and not force:
status_data = json.loads(current_status)
if status_data.get("status") == "processing":
print("已有分析任务正在进行中")
return {
"message": "项目分析任务正在进行中",
"status": "processing",
"project_id": project_id,
"start_time": status_data.get("start_time")
}
# 记录分析开始状态
status_data = {
"status": "processing",
"start_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
print("启动分析线程...")
# 创建并启动新线程
import threading
analysis_thread = threading.Thread(
target=run_project_in_thread,
args=(project_id,),
daemon=True
)
analysis_thread.start()
print("分析线程已启动")
return {
"message": "项目分析任务已启动",
"status": "processing",
"project_id": project_id,
"start_time": status_data["start_time"]
}
except Exception as e:
print(f"Error starting project analysis: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
def run_project_in_thread(project_id: str):
"""在独立线程中运行分析任务"""
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# 在新的事件循环中运行异步任务
loop.run_until_complete(process_project_analysis(project_id))
except Exception as e:
print(f"Error in analysis thread: {str(e)}")
finally:
loop.close()
async def process_project_analysis(project_id: str):
"""后台处理项目分析任务"""
print(f"\n=== 开始项目分析 ===")
print(f"项目ID: {project_id}")
# 创建新的数据库和Redis连接
mongo_client = AsyncIOMotorClient(MONGODB_URL)
db = mongo_client["lab"]
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
try:
await redis.select(202)
status_key = f"project_analysis_status:{project_id}"
# 更新状态为进行中
status_data = {
"status": "processing",
"start_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
# 获取项目名称
project = await db.projects.find_one({"_id": ObjectId(project_id)})
project_name = project["project_name"] if project else "Unknown Project"
# 从MongoDB查找项目下的所有实验
experiment_cursor = db.experiments.find({
"project_id": ObjectId(project_id)
})
experiment_ids = []
experiment_names = {} # 存储实验ID和名称的映射
async for experiment in experiment_cursor:
exp_id = str(experiment["_id"])
experiment_ids.append(exp_id)
experiment_names[exp_id] = experiment["experiment_name"]
if not experiment_ids:
print("错误: 未找到任何实验")
raise Exception("No experiments found in this project")
# 从Redis db201获取所有实验报告
await redis.select(201)
all_experiment_reports = {}
# 统计数据初始化
total_sessions = 0
total_duration = 0
total_data_points = 0
total_devices = set()
total_sensors = 0
for exp_id in experiment_ids:
report_key = f"experiment_report:{exp_id}"
report_data = await redis.get(report_key)
if report_data:
try:
parsed_report = json.loads(report_data)
all_experiment_reports[exp_id] = parsed_report
# 从Basic Information中提取数据
basic_info = parsed_report["Experiment Analysis Report"]["1. Basic Information"]
total_sessions += int(basic_info["Total Sessions"])
total_duration += float(basic_info["Total Duration"].split()[0]) # 去掉"seconds"
total_data_points += int(basic_info["Data Points"])
total_devices.update([str(i) for i in range(int(basic_info["Devices number"]))])
total_sensors += int(basic_info["Sensors number"])
except (json.JSONDecodeError, KeyError) as e:
continue
if not all_experiment_reports:
raise Exception("No experiment reports found")
# 计算平均值
avg_session_duration = total_duration / total_sessions if total_sessions > 0 else 0
avg_data_points_per_session = total_data_points / total_sessions if total_sessions > 0 else 0
# 添加项目级统计数据
project_stats = {
"project_name": project_name,
"total_experiments": len(experiment_ids),
"total_sessions": total_sessions,
"total_duration": total_duration,
"total_data_points": total_data_points,
"total_devices": len(total_devices),
"total_sensors": total_sensors,
"avg_session_duration": avg_session_duration,
"avg_data_points_per_session": avg_data_points_per_session,
"experiment_names": experiment_names
}
# 将统计数据和原始报告一起发送给分析函数
project_data = {
"project_stats": project_stats,
"experiment_reports": all_experiment_reports
}
# 执行分析
analysis_result = await analyze_project_data(project_data)
if analysis_result:
# 保存分析结果
await redis.select(202)
report_key = f"project_report:{project_id}"
await redis.set(report_key, json.dumps(analysis_result))
# 更新状态为完成
status_key = f"project_analysis_status:{project_id}"
status_data = {
"status": "completed",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
print("分析报告已保存")
else:
print("\n7. 分析失败")
# 更新失败状态
status_data = {
"status": "failed",
"error": "Failed to generate analysis result",
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
except Exception as e:
try:
await redis.select(202)
status_key = f"project_analysis_status:{project_id}"
status_data = {
"status": "failed",
"error": str(e),
"completion_time": datetime.now(timezone.utc).isoformat()
}
await redis.set(status_key, json.dumps(status_data))
except Exception as redis_error:
print(f"Error updating Redis status: {redis_error}")
finally:
try:
await redis.aclose()
mongo_client.close()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/projects/{project_id}/analysis_status")
async def get_project_analysis_status(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取项目分析任务的状态"""
redis = await get_redis()
try:
await redis.select(202)
status_key = f"project_analysis_status:{project_id}"
status_data = await redis.get(status_key)
if not status_data:
return {
"status": "not_started",
"project_id": project_id
}
return json.loads(status_data)
except Exception as e:
print(f"Error getting analysis status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/projects/{project_id}/report")
async def get_project_saved_report(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""从 Redis db202 读取已保存的项目报告"""
redis = await get_redis()
try:
# 选择 db202
await redis.select(202)
report_key = f"project_report:{project_id}"
# 获取已保存的报告
existing_report = await redis.get(report_key)
if not existing_report:
raise HTTPException(status_code=404, detail="No saved project report found")
# 返回报告
return json.loads(existing_report)
except Exception as e:
print(f"Error getting project report: {e}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
class QuestionModel(BaseModel):
"""问题模型"""
question: str
# 项目问答
@router.post("/lab/projects/{project_id}/qa")
async def ask_project_question(
project_id: str,
question: QuestionModel,
current_user: UserModel = Depends(get_current_user)
):
"""向项目报告提问(异步)"""
db = await get_database()
redis = await get_redis()
try:
# 验证项目是否存在
project = await db.projects.find_one({"_id": ObjectId(project_id)})
if not project:
raise HTTPException(status_code=404, detail="项目不存在")
# 生成任务ID
task_id = str(ObjectId())
# 创建后台任务
asyncio.create_task(process_project_question(
task_id=task_id,
project_id=project_id,
question=question.question,
project=project
))
return {
"task_id": task_id,
"status": TaskStatus.PENDING,
"message": "问题已提交,正在处理中"
}
except Exception as e:
print(f"Error in ask_project_question: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
async def process_project_question(task_id: str, project_id: str, question: str, project: dict):
"""处理项目问答的后台任务"""
redis = await get_redis()
try:
# 更新任务状态为处理中
await redis.select(208) # 使用db208存储任务状态
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.PROCESSING,
"project_id": project_id,
"question": question
}
)
# 从Redis获取分析报告
await redis.select(202)
report_key = f"project_report:{project_id}"
report_data = await redis.get(report_key)
if not report_data:
raise Exception("项目分析报告不存在,请先进行分析")
# 构建系统提示和用户提示
system_prompt = """
你是一个专业的项目助手,负责回答关于项目报告的问题。
你应该基于项目报告提供准确、专业的回答。
回答应当简洁明了,并尽可能引用报告中的具体内容。
"""
# 构建上下文
context = {
"项目名称": project.get("project_name", "未知项目"),
"分析报告": json.loads(report_data)
}
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"基于以下项目报告回答问题:\n\n项目信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
]
# 调用DeepSeek API获取回答
response = client.chat.completions.create(
model="deepseek-chat",
messages=messages
)
answer = response.choices[0].message.content
# 保存对话历史到Redis db207
await redis.select(207)
chat_history_key = f"project_chat_history:{project_id}"
# 获取现有历史记录
existing_history = await redis.get(chat_history_key)
history = json.loads(existing_history) if existing_history else []
# 添加新的对话
history.append({
"question": question,
"answer": answer,
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 保存更新后的历史记录
await redis.set(chat_history_key, json.dumps(history))
# 更新任务状态为完成
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.COMPLETED,
"answer": answer,
"project_name": project.get("project_name")
}
)
except Exception as e:
print(f"Error processing question: {e}")
# 更新任务状态为失败
await redis.select(208)
await redis.hset(
f"task:{task_id}",
mapping={
"status": TaskStatus.FAILED,
"error": str(e)
}
)
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
@router.get("/lab/projects/{project_id}/qa/history")
async def get_project_qa_history(
project_id: str,
current_user: UserModel = Depends(get_current_user)
):
"""获取项目问答历史记录"""
redis = await get_redis()
try:
# 从Redis db207获取对话历史
await redis.select(207)
chat_history_key = f"project_chat_history:{project_id}"
history_data = await redis.get(chat_history_key)
if not history_data:
return []
return json.loads(history_data)
except Exception as e:
print(f"Error getting QA history: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
finally:
try:
await redis.aclose()
except Exception as e:
print(f"Error closing Redis connection: {e}")
+142
View File
@@ -0,0 +1,142 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Dict, Set
import asyncio
import json
import time
router = APIRouter()
# Connection Manager for WebSocket clients
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, Dict[str, Set[WebSocket]]] = {}
self._lock = asyncio.Lock()
# 添加消息队列
self.message_queues: Dict[str, asyncio.Queue] = {}
self.broadcast_tasks: Dict[str, asyncio.Task] = {}
async def connect(self, websocket: WebSocket, experiment_id: str, serial_number: str):
"""添加新的WebSocket连接"""
async with self._lock:
if experiment_id not in self.active_connections:
self.active_connections[experiment_id] = {}
if serial_number not in self.active_connections[experiment_id]:
self.active_connections[experiment_id][serial_number] = set()
self.active_connections[experiment_id][serial_number].add(websocket)
print(f"新连接已添加到管理器 - 实验ID: {experiment_id}, 设备: {serial_number}")
async def disconnect(self, websocket: WebSocket, experiment_id: str, serial_number: str):
"""移除WebSocket连接"""
async with self._lock:
try:
if (experiment_id in self.active_connections and
serial_number in self.active_connections[experiment_id]):
self.active_connections[experiment_id][serial_number].remove(websocket)
# 只在成功移除连接时打印一次日志
print(f"WebSocket连接已断开 - 实验ID: {experiment_id}, 设备: {serial_number}")
# 清理空集合
if not self.active_connections[experiment_id][serial_number]:
del self.active_connections[experiment_id][serial_number]
if not self.active_connections[experiment_id]:
del self.active_connections[experiment_id]
except KeyError:
pass
except Exception as e:
print(f"断开连接时出错: {e}")
async def broadcast_worker(self, experiment_id: str, serial_number: str):
"""单独的广播工作器"""
queue = self.message_queues[f"{experiment_id}:{serial_number}"]
while True:
try:
message = await queue.get()
disconnected = set()
async with self._lock:
connections = self.active_connections[experiment_id][serial_number].copy()
await asyncio.gather(*[
self.send_message(ws, message, disconnected)
for ws in connections
], return_exceptions=True)
queue.task_done()
# 清理断开的连接
for websocket in disconnected:
await self.disconnect(websocket, experiment_id, serial_number)
except Exception as e:
print(f"广播工作器错误: {e}")
async def broadcast(self, message: str, experiment_id: str, serial_number: str):
"""使用消息队列进行广播"""
queue_key = f"{experiment_id}:{serial_number}"
if queue_key not in self.message_queues:
self.message_queues[queue_key] = asyncio.Queue()
self.broadcast_tasks[queue_key] = asyncio.create_task(
self.broadcast_worker(experiment_id, serial_number)
)
await self.message_queues[queue_key].put(message)
async def send_message(self, websocket: WebSocket, message: str, disconnected: set):
"""处理单个连接的消息发送"""
try:
await websocket.send_text(message)
except Exception as e:
disconnected.add(websocket)
# 创建连接管理器实例
manager = ConnectionManager()
@router.websocket("/lab/ws/{experiment_id}/{serial_number}")
async def websocket_endpoint(
websocket: WebSocket,
experiment_id: str,
serial_number: str
):
try:
await websocket.accept()
print(f"WebSocket连接已建立 - 实验ID: {experiment_id}, 设备: {serial_number}")
# 添加到连接管理器
await manager.connect(websocket, experiment_id, serial_number)
# 发送初始连接响应
response = {
"type": "connect_response",
"status": "success",
"message": "连接成功",
"experiment_id": experiment_id,
"serial_number": serial_number
}
await websocket.send_text(json.dumps(response))
try:
# 保持连接并持续接收消息
while True:
try:
data = await websocket.receive_text()
# 只处理非心跳消息
if data != "ping":
message = json.loads(data)
await manager.broadcast(data, experiment_id, serial_number)
except WebSocketDisconnect:
break
except json.JSONDecodeError:
continue # 忽略无效的JSON数据
except Exception as e:
print(f"处理WebSocket消息时出错: {e}")
finally:
await manager.disconnect(websocket, experiment_id, serial_number)
@router.get("/lab/status")
async def get_status():
return {
"status": "running",
"timestamp": time.time(),
"active_connections": len(manager.active_connections)
}
View File
View File
View File
View File