update
This commit is contained in:
+12
-18
@@ -1,37 +1,30 @@
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from enum import Enum
|
||||
from openai import OpenAI
|
||||
# MongoDB 配置
|
||||
MONGODB_URL = "mongodb://localhost:27017"
|
||||
MONGODB_URL = "mongodb://lab:y6aHwySAhzrbibLD@222.186.10.253:27017/lab"
|
||||
DATABASE_NAME = "lab"
|
||||
|
||||
# Redis 配置
|
||||
REDIS_URL = "redis://localhost:6379"
|
||||
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
|
||||
|
||||
# DeepSeek API 配置
|
||||
DEEPSEEK_API_CONFIG = {
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
"api_key": "sk-3027fb3c810b4e17985fa397d41250b9"
|
||||
}
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
# 文件上传配置
|
||||
UPLOAD_PATH = "/obscura/task/references"
|
||||
|
||||
# Redis 数据库映射
|
||||
REDIS_DB_MAP = {
|
||||
"experiment_analysis": 201, # 实验分析报告
|
||||
"project_analysis": 202, # 项目分析报告
|
||||
"reference_analysis": 203, # 文献分析报告
|
||||
"reference_summary": 204, # 文献汇总分析
|
||||
"project_memo": 205, # 项目备忘录
|
||||
"experiment_memo": 206, # 实验备忘录
|
||||
"chat_history": 207, # 聊天历史记录
|
||||
"task_status": 208 # 任务状态
|
||||
}
|
||||
|
||||
# 任务状态枚举
|
||||
class TaskStatus:
|
||||
class TaskStatus(Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
@@ -41,7 +34,8 @@ class TaskStatus:
|
||||
THREAD_POOL_CONFIG = {
|
||||
"pdf_workers": 3,
|
||||
"analysis_workers": 3,
|
||||
"pro_analysis_workers": 3
|
||||
"pro_analysis_workers": 3,
|
||||
"summary_workers": 3
|
||||
}
|
||||
|
||||
# API 响应格式
|
||||
|
||||
+94
-4
@@ -3,10 +3,77 @@ from pydantic import BaseModel, Field
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from contextlib import asynccontextmanager
|
||||
from redis import asyncio as aioredis
|
||||
from .config import MONGODB_URL, REDIS_URL
|
||||
|
||||
# Database Configuration
|
||||
MONGODB_URL = "mongodb://lab:y6aHwySAhzrbibLD@222.186.10.253:27017/lab"
|
||||
REDIS_URL = "redis://:Obscura@2024@222.186.10.253:6379"
|
||||
# Redis数据库映射
|
||||
REDIS_DB_MAPPING = {
|
||||
"experiment_status": {
|
||||
"db": 199,
|
||||
"description": "存储实验状态",
|
||||
"key_pattern": "experiment_status:{experiment_id}:{device_serial}"
|
||||
},
|
||||
"raw_data": {
|
||||
"db": 200,
|
||||
"description": "存储原始数据",
|
||||
"key_pattern": "experiment:{device_serial}"
|
||||
},
|
||||
"experiment_analysis": {
|
||||
"db": 201,
|
||||
"description": "存储实验分析报告",
|
||||
"key_patterns": {
|
||||
"report": "experiment_report:{experiment_id}",
|
||||
"status": "experiment_analysis_status:{experiment_id}"
|
||||
}
|
||||
},
|
||||
"project_analysis": {
|
||||
"db": 202,
|
||||
"description": "存储项目分析报告",
|
||||
"key_patterns": {
|
||||
"report": "project_report:{project_id}",
|
||||
"status": "project_analysis_status:{project_id}"
|
||||
}
|
||||
},
|
||||
"reference_analysis": {
|
||||
"db": 203,
|
||||
"description": "存储文献分析报告",
|
||||
"key_patterns": {
|
||||
"report": "reference_report:{reference_id}",
|
||||
"status": "reference_analysis_status:{reference_id}"
|
||||
}
|
||||
},
|
||||
"reference_summary": {
|
||||
"db": 204,
|
||||
"description": "存储文献分析状态和汇总报告",
|
||||
"key_patterns": {
|
||||
"status": "reference_analysis_status:{project_id}",
|
||||
"report": "reference_summary_report:{project_id}"
|
||||
}
|
||||
},
|
||||
"project_memo": {
|
||||
"db": 205,
|
||||
"description": "存储项目备忘录",
|
||||
"key_pattern": "project_memo:{project_id}"
|
||||
},
|
||||
"experiment_memo": {
|
||||
"db": 206,
|
||||
"description": "存储实验备忘录",
|
||||
"key_pattern": "experiment_memo:{experiment_id}"
|
||||
},
|
||||
"chat_history": {
|
||||
"db": 207,
|
||||
"description": "存储问答历史记录",
|
||||
"key_patterns": {
|
||||
"reference": "chat_history:{reference_id}",
|
||||
"experiment": "experiment_chat_history:{experiment_id}",
|
||||
"project": "project_chat_history:{project_id}"
|
||||
}
|
||||
},
|
||||
"qa_task": {
|
||||
"db": 208,
|
||||
"description": "存储问答任务状态",
|
||||
"key_pattern": "task:{task_id}"
|
||||
}
|
||||
}
|
||||
|
||||
class PyObjectId(ObjectId):
|
||||
"""
|
||||
@@ -50,4 +117,27 @@ async def get_redis():
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
return redis
|
||||
return redis
|
||||
|
||||
async def get_redis_key(redis_type: str, key_pattern: str, **kwargs) -> tuple[int, str]:
|
||||
"""
|
||||
获取Redis数据库编号和格式化后的键名
|
||||
|
||||
参数:
|
||||
redis_type: REDIS_DB_MAPPING中的键名
|
||||
key_pattern: 键模式名称
|
||||
**kwargs: 用于格式化键模式的参数
|
||||
|
||||
返回:
|
||||
tuple[int, str]: (数据库编号, 格式化后的键名)
|
||||
"""
|
||||
db_info = REDIS_DB_MAPPING[redis_type]
|
||||
db_number = db_info["db"]
|
||||
|
||||
if "key_patterns" in db_info:
|
||||
pattern = db_info["key_patterns"][key_pattern]
|
||||
else:
|
||||
pattern = db_info["key_pattern"]
|
||||
|
||||
key = pattern.format(**kwargs)
|
||||
return db_number, key
|
||||
+3
-7
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from openai import OpenAI
|
||||
|
||||
# Local application imports
|
||||
from .cores.config import DEEPSEEK_API_CONFIG
|
||||
from .cores.config import client
|
||||
from .cores.db import (
|
||||
connect_to_mongo,
|
||||
close_mongo_connection,
|
||||
@@ -22,12 +22,8 @@ from .routers.websocket import router as websocket_router
|
||||
from .routers.project_report import router as project_report_router
|
||||
from .routers.memo import router as memo_router
|
||||
from .routers.paper import router as paper_router
|
||||
from .routers.paper_summary import router as paper_summary_router
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -51,7 +47,7 @@ app.include_router(websocket_router)
|
||||
app.include_router(project_report_router)
|
||||
app.include_router(memo_router)
|
||||
app.include_router(paper_router)
|
||||
|
||||
app.include_router(paper_summary_router)
|
||||
# CORS configuration
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from bson import ObjectId
|
||||
from ..cores.db import PyObjectId
|
||||
|
||||
class ExperimentDeviceModel(BaseModel):
|
||||
"""实验-设备关联模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_device_id: PyObjectId
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
"""问题模型"""
|
||||
question: str
|
||||
|
||||
class UserModel(BaseModel):
|
||||
"""用户模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
username: str
|
||||
password: str
|
||||
email: EmailStr
|
||||
name: str
|
||||
institution: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ProjectModel(BaseModel):
|
||||
"""项目模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
user_id: Optional[PyObjectId] = None
|
||||
project_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentModel(BaseModel):
|
||||
"""实验模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
project_id: PyObjectId
|
||||
experiment_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
status: str = Field(default="active")
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentCreate(BaseModel):
|
||||
"""实验创建请求模型"""
|
||||
project_id: str
|
||||
experiment_name: str
|
||||
description: str | None = None
|
||||
|
||||
class ExperimentSession(BaseModel):
|
||||
"""实验会话模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_id: PyObjectId
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentData(BaseModel):
|
||||
"""实验数据模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
session_ids: List[PyObjectId]
|
||||
user_id: PyObjectId
|
||||
device_id: str
|
||||
sensor_name: str
|
||||
last_update: datetime
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class MemoModel(BaseModel):
|
||||
"""备忘录模型"""
|
||||
content: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class FormulaModel(BaseModel):
|
||||
"""公式模型"""
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
class FormulaCreate(BaseModel):
|
||||
"""公式创建请求模型"""
|
||||
sensor_name: str
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
class ExperimentStatus:
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
|
||||
# 在其他 Pydantic 模型后添加
|
||||
class ReferenceModel(BaseModel):
|
||||
"""文献引用模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
project_id: PyObjectId
|
||||
reference_link: str
|
||||
reference_title: str
|
||||
upload_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
@@ -0,0 +1,354 @@
|
||||
from ..cores.config import DEEPSEEK_API_CONFIG, TaskStatus, client
|
||||
from ..cores.db import get_redis, get_redis_key
|
||||
from ..cores.config import MONGODB_URL, REDIS_URL, REDIS_DB_MAP
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
from redis import asyncio as aioredis
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ..cores.config import THREAD_POOL_CONFIG
|
||||
|
||||
|
||||
|
||||
# 创建项目分析线程池
|
||||
exp_analysis_executor = ThreadPoolExecutor(
|
||||
max_workers=THREAD_POOL_CONFIG["exp_analysis_workers"],
|
||||
thread_name_prefix="exp_analysis"
|
||||
)
|
||||
|
||||
async def analyze_experiment_data(experiment_info):
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with analyzing experimental data.
|
||||
Generate a comprehensive experiment analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the experiment data based on the following information:
|
||||
Experiment data: {json.dumps(experiment_info['sessions'], ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Experiment Analysis Report": {{
|
||||
"1. Basic Information": {{
|
||||
"Total Sessions": "{experiment_info['total_sessions']}",
|
||||
"Total Duration": "{experiment_info['total_duration']} seconds",
|
||||
"Data Points": "{experiment_info['total_points']}",
|
||||
"Devices number": "{experiment_info['device_stats']['total_devices']}",
|
||||
"Sensors number": "{experiment_info['device_stats']['total_sensors']}"
|
||||
}},
|
||||
"2. Session Data Analysis": {{
|
||||
"[Session ID]": {{
|
||||
"Duration": "[Duration] seconds",
|
||||
"Data Points": "[data_points]"
|
||||
}},
|
||||
// Repeat for each session dynamically
|
||||
}},
|
||||
"3. Key Findings": [
|
||||
"[Finding 1]",
|
||||
"[Finding 2]",
|
||||
"[Finding 3]"
|
||||
],
|
||||
"4. Recommendations": [
|
||||
"[Recommendation 1]",
|
||||
"[Recommendation 2]",
|
||||
"[Recommendation 3]"
|
||||
]
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
def run_experiment_in_thread(experiment_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_experiment_analysis(experiment_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
# 清理所有待处理的任务
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
# 运行直到所有任务完成
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up tasks: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_experiment_analysis(experiment_id: str):
|
||||
"""后台处理实验分析任务"""
|
||||
print(f"\n=== 开始实验分析 ===")
|
||||
print(f"实验ID: {experiment_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取实验基本信息
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise ValueError("实验不存在")
|
||||
|
||||
# 收集会话数据
|
||||
sessions = []
|
||||
total_points = 0
|
||||
total_duration = 0
|
||||
devices_set = set()
|
||||
sensors_count = 0
|
||||
|
||||
async for session in db.experiment_sessions.find({
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
}):
|
||||
# 计算会话持续时间
|
||||
end_time = session.get("end_time") or datetime.now(timezone.utc)
|
||||
duration = (end_time - session["start_time"]).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
# 转换时间戳为毫秒
|
||||
start_ms = int(session["start_time"].timestamp() * 1000)
|
||||
end_ms = int(end_time.timestamp() * 1000)
|
||||
|
||||
# 统计设备和传感器数量
|
||||
session_devices = session.get("devices", [])
|
||||
for device in session_devices:
|
||||
devices_set.add(device["serial_number"])
|
||||
sensors_count += len(device.get("sensors", []))
|
||||
|
||||
# 获取会话数据点数
|
||||
session_points = 0
|
||||
for device in session_devices:
|
||||
stream_key = f"experiment:{experiment_id}:{device['serial_number']}"
|
||||
await redis.select(200)
|
||||
data_points = await redis.xrange(
|
||||
stream_key,
|
||||
min=str(start_ms),
|
||||
max=str(end_ms)
|
||||
)
|
||||
session_points += len(data_points)
|
||||
|
||||
total_points += session_points
|
||||
sessions.append({
|
||||
"session_id": str(session["_id"]),
|
||||
"duration": duration,
|
||||
"data_points": session_points
|
||||
})
|
||||
|
||||
if not sessions:
|
||||
raise ValueError("没有找到实验会话数据")
|
||||
|
||||
experiment_info = {
|
||||
"experiment_name": experiment["experiment_name"],
|
||||
"total_sessions": len(sessions),
|
||||
"total_duration": total_duration,
|
||||
"total_points": total_points,
|
||||
"device_stats": {
|
||||
"total_devices": len(devices_set),
|
||||
"total_sensors": sensors_count
|
||||
},
|
||||
"sessions": sessions
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_experiment_data(experiment_info)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
db_number, report_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"report",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态
|
||||
db_number, status_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"status",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("分析失败")
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in experiment analysis: {str(e)}")
|
||||
try:
|
||||
db_number, status_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"status",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing connections: {e}")
|
||||
|
||||
|
||||
async def process_experiment_question(task_id: str, experiment_id: str, question: str, experiment: dict):
|
||||
"""处理实验问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态
|
||||
db_number, task_key = await get_redis_key(
|
||||
"qa_task",
|
||||
"key_pattern",
|
||||
task_id=task_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"experiment_id": experiment_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 获取分析报告
|
||||
db_number, report_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"report",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("实验分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的实验助手,负责回答关于实验报告的问题。
|
||||
你应该基于实验报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"实验名称": experiment.get("experiment_name", "未知实验"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下实验报告回答问题:\n\n实验信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史
|
||||
db_number, chat_history_key = await get_redis_key(
|
||||
"chat_history",
|
||||
"experiment",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"experiment_name": experiment.get("experiment_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
+1
-296
@@ -1,306 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Dict
|
||||
import PyPDF2
|
||||
import aiohttp
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from openai import OpenAI
|
||||
from ..cores.config import DEEPSEEK_API_CONFIG
|
||||
from ..cores.db import get_redis
|
||||
from enum import Enum
|
||||
|
||||
# 创建OpenAI客户端
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
# 修改函数定义为异步函数
|
||||
async def analyze_reference_summary(reference_data):
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing paper reports.
|
||||
You will summarize and analyze all paper reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the following paper reports:
|
||||
Paper reports: {json.dumps(reference_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Paper Summary Report": {{
|
||||
"overview": {{
|
||||
"total_papers": "[Number of papers]",
|
||||
"time_range": {{"start_year": "[Start year]", "end_year": "[End year]"}},
|
||||
"main_research_areas": "[Main research areas of the papers]"
|
||||
}},
|
||||
"research_trends": {{
|
||||
"major_themes": "[Major themes areas of the papers]",
|
||||
"common_methodologies": "[Common methodologies used in the papers]",
|
||||
"emerging_topics": "[Emerging topics in the papers]"
|
||||
}},
|
||||
"key_findings": {{
|
||||
"theoretical_advances": "[Theoretical advances in the papers]",
|
||||
"experimental_results": "[Experimental results in the papers]",
|
||||
"common_conclusions": "[Common conclusions in the papers]"
|
||||
}},
|
||||
"research_gaps": {{
|
||||
"current_limitations": "[Current limitations]",
|
||||
"unexplored_areas": "[Unexplored areas]",
|
||||
"technical_challenges": "[Technical challenges"
|
||||
}},
|
||||
"future_directions": {{
|
||||
"potential_applications": "[Potential future applications]",
|
||||
"methodological_suggestions": "[Methodological suggestions based on paper summary]"
|
||||
}},
|
||||
"impact_assessment": {{
|
||||
"academic_influence": "[Summarize the academic influence of the paper]",
|
||||
"practical_value": "[Summarize the practical value of the paper]"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def run_analysis_in_thread(project_id: str, reference_ids: List[str]):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_reference_analysis(project_id, reference_ids))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_reference_analysis(project_id: str, reference_ids: List[str]):
|
||||
"""后台处理文献分析任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新状态键
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
|
||||
# 从Redis db203获取所有实验报告
|
||||
await redis.select(203)
|
||||
all_reference_reports = {}
|
||||
completed_count = 0
|
||||
|
||||
for ref_id in reference_ids:
|
||||
report_key = f"reference_report:{ref_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_reference_reports[ref_id] = parsed_report
|
||||
completed_count += 1
|
||||
|
||||
# 更新进度
|
||||
await redis.select(204)
|
||||
current_status = await redis.get(status_key)
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
status_data["completed_references"] = completed_count
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
await redis.select(203)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error parsing reference {ref_id} report")
|
||||
continue
|
||||
|
||||
if all_reference_reports:
|
||||
analysis_result = await analyze_reference_summary(all_reference_reports)
|
||||
|
||||
if analysis_result:
|
||||
await redis.select(204)
|
||||
report_key = f"reference_summary_report:{project_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新最终状态
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat(),
|
||||
"total_references": len(reference_ids),
|
||||
"completed_references": completed_count
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
print(f"Reference analysis completed for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print(f"Failed to generate analysis for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "No valid reference reports found",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print(f"No valid reference reports found for project {project_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in background reference analysis: {str(e)}")
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def process_reference_question(task_id: str, reference_id: str, question: str, reference: dict):
|
||||
"""处理文献问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"reference_id": reference_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{reference_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("文献分析报告不存在,请先进行分析")
|
||||
|
||||
# 读取PDF文件内容
|
||||
file_path = reference.get("reference_link")
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
raise Exception("文献文件不存在")
|
||||
|
||||
# 读取PDF内容
|
||||
pdf_content = ""
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
for page in pdf_reader.pages:
|
||||
content = page.extract_text()
|
||||
pdf_content += content
|
||||
|
||||
# 限制内容长度
|
||||
MAX_CHARS = 180000
|
||||
pdf_content = pdf_content[:MAX_CHARS]
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的学术助手,负责回答关于学术文献的问题。
|
||||
你应该基于文献内容和分析报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用文献中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"文献内容": pdf_content,
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下文献内容和分析报告回答问题:\n\n文献信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"chat_history:{reference_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"reference_title": reference.get("reference_title")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
from ..cores.config import client
|
||||
|
||||
|
||||
# 在全局范围创建线程池
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
import json
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Dict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ..cores.config import THREAD_POOL_CONFIG, client
|
||||
from ..cores.db import get_redis, get_redis_key
|
||||
|
||||
# 创建线程池
|
||||
summary_executor = ThreadPoolExecutor(max_workers=THREAD_POOL_CONFIG["summary_workers"])
|
||||
|
||||
# 修改函数定义为异步函数
|
||||
async def analyze_reference_summary(reference_data):
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing paper reports.
|
||||
You will summarize and analyze all paper reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the following paper reports:
|
||||
Paper reports: {json.dumps(reference_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Paper Summary Report": {{
|
||||
"overview": {{
|
||||
"total_papers": "[Number of papers]",
|
||||
"time_range": {{"start_year": "[Start year]", "end_year": "[End year]"}},
|
||||
"main_research_areas": "[Main research areas of the papers]"
|
||||
}},
|
||||
"research_trends": {{
|
||||
"major_themes": "[Major themes areas of the papers]",
|
||||
"common_methodologies": "[Common methodologies used in the papers]",
|
||||
"emerging_topics": "[Emerging topics in the papers]"
|
||||
}},
|
||||
"key_findings": {{
|
||||
"theoretical_advances": "[Theoretical advances in the papers]",
|
||||
"experimental_results": "[Experimental results in the papers]",
|
||||
"common_conclusions": "[Common conclusions in the papers]"
|
||||
}},
|
||||
"research_gaps": {{
|
||||
"current_limitations": "[Current limitations]",
|
||||
"unexplored_areas": "[Unexplored areas]",
|
||||
"technical_challenges": "[Technical challenges"
|
||||
}},
|
||||
"future_directions": {{
|
||||
"potential_applications": "[Potential future applications]",
|
||||
"methodological_suggestions": "[Methodological suggestions based on paper summary]"
|
||||
}},
|
||||
"impact_assessment": {{
|
||||
"academic_influence": "[Summarize the academic influence of the paper]",
|
||||
"practical_value": "[Summarize the practical value of the paper]"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def run_analysis_in_thread(project_id: str, reference_ids: List[str]):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 使用summary_executor来执行任务
|
||||
summary_executor.submit(lambda: asyncio.run(process_reference_analysis(project_id, reference_ids)))
|
||||
|
||||
async def process_reference_analysis(project_id: str, reference_ids: List[str]):
|
||||
"""后台处理文献分析任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新状态键
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 获取所有实验报告
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_analysis",
|
||||
"report",
|
||||
reference_id=ref_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
all_reference_reports = {}
|
||||
completed_count = 0
|
||||
|
||||
for ref_id in reference_ids:
|
||||
report_key = f"reference_report:{ref_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_reference_reports[ref_id] = parsed_report
|
||||
completed_count += 1
|
||||
|
||||
# 更新进度
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
current_status = await redis.get(status_key)
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
status_data["completed_references"] = completed_count
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error parsing reference {ref_id} report")
|
||||
continue
|
||||
|
||||
if all_reference_reports:
|
||||
analysis_result = await analyze_reference_summary(all_reference_reports)
|
||||
|
||||
if analysis_result:
|
||||
# 保存汇总报告
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"report",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新最终状态
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat(),
|
||||
"total_references": len(reference_ids),
|
||||
"completed_references": completed_count
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
print(f"Reference analysis completed for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
print(f"Failed to generate analysis for project {project_id}")
|
||||
else:
|
||||
# 更新失败状态
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "failed",
|
||||
"error": "No valid reference reports found",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
print(f"No valid reference reports found for project {project_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in background reference analysis: {str(e)}")
|
||||
try:
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -0,0 +1,366 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import asyncio
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from redis import asyncio as aioredis
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ..cores.config import MONGODB_URL, REDIS_URL, TaskStatus, THREAD_POOL_CONFIG, client
|
||||
from ..cores.db import get_redis, get_redis_key
|
||||
|
||||
# 创建项目分析线程池
|
||||
pro_analysis_executor = ThreadPoolExecutor(
|
||||
max_workers=THREAD_POOL_CONFIG["pro_analysis_workers"],
|
||||
thread_name_prefix="pro_analysis"
|
||||
)
|
||||
|
||||
|
||||
async def analyze_project_data(project_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing lab reports.
|
||||
You will summarize and analyze all lab reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""The project report is summarized based on the following experimental analysis reports:
|
||||
Experimental analysis reports: {json.dumps(project_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Project Analysis Report": {{
|
||||
"1. Project Overview": {{
|
||||
"Project Name": "{project_data['project_stats']['project_name']}",
|
||||
"Total Experiments": "{project_data['project_stats']['total_experiments']}",
|
||||
"Total Data Points": "{project_data['project_stats']['total_data_points']}",
|
||||
}},
|
||||
"2. Aggregated Statistics": {{
|
||||
"Total Sessions": {project_data['project_stats']['total_sessions']},
|
||||
"Total Duration": {project_data['project_stats']['total_duration']},
|
||||
"Total Data Points": {project_data['project_stats']['total_data_points']},
|
||||
"Average Session Duration": {project_data['project_stats']['avg_session_duration']},
|
||||
"Average Data Points per Session": {project_data['project_stats']['avg_data_points_per_session']}
|
||||
}},
|
||||
"3. Performance Analysis": {{
|
||||
"Best Performing Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Duration": "duration",
|
||||
"Data Points": "data_points",
|
||||
"Success Factors": "success_factors"
|
||||
}}
|
||||
],
|
||||
"Problematic Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Issues": "specific_issues",
|
||||
"Possible Causes": "possible_causes"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
"4. Common Findings": {{
|
||||
"Recurring Issues": "common_failure_modes",
|
||||
"Equipment performance": "equipment_performance",
|
||||
"Sensor reliability analysis": "sensor_reliability_analysis"
|
||||
}},
|
||||
"5. Recommendations": {{
|
||||
"Equipment optimization suggestions": "equipment_optimization_suggestions",
|
||||
"Experiment process improvement suggestions": "experiment_process_improvement_suggestions",
|
||||
"Data collection strategy adjustment": "data_collection_strategy_adjustment"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
def run_project_in_thread(project_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_project_analysis(project_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_project_analysis(project_id: str):
|
||||
"""后台处理项目分析任务"""
|
||||
print(f"\n=== 开始项目分析 ===")
|
||||
print(f"项目ID: {project_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
db_number, status_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"status", # 使用 status 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取项目名称
|
||||
project = await db.projects.find_one({"_id": ObjectId(project_id)})
|
||||
project_name = project["project_name"] if project else "Unknown Project"
|
||||
|
||||
# 从MongoDB查找项目下的所有实验
|
||||
experiment_cursor = db.experiments.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
experiment_ids = []
|
||||
experiment_names = {} # 存储实验ID和名称的映射
|
||||
async for experiment in experiment_cursor:
|
||||
exp_id = str(experiment["_id"])
|
||||
experiment_ids.append(exp_id)
|
||||
experiment_names[exp_id] = experiment["experiment_name"]
|
||||
|
||||
if not experiment_ids:
|
||||
print("错误: 未找到任何实验")
|
||||
raise Exception("No experiments found in this project")
|
||||
|
||||
# 从Redis db201获取所有实验报告
|
||||
await redis.select(201)
|
||||
all_experiment_reports = {}
|
||||
|
||||
# 统计数据初始化
|
||||
total_sessions = 0
|
||||
total_duration = 0
|
||||
total_data_points = 0
|
||||
total_devices = set()
|
||||
total_sensors = 0
|
||||
|
||||
for exp_id in experiment_ids:
|
||||
report_key = f"experiment_report:{exp_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_experiment_reports[exp_id] = parsed_report
|
||||
|
||||
# 从Basic Information中提取数据
|
||||
basic_info = parsed_report["Experiment Analysis Report"]["1. Basic Information"]
|
||||
total_sessions += int(basic_info["Total Sessions"])
|
||||
total_duration += float(basic_info["Total Duration"].split()[0]) # 去掉"seconds"
|
||||
total_data_points += int(basic_info["Data Points"])
|
||||
total_devices.update([str(i) for i in range(int(basic_info["Devices number"]))])
|
||||
total_sensors += int(basic_info["Sensors number"])
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
continue
|
||||
|
||||
if not all_experiment_reports:
|
||||
raise Exception("No experiment reports found")
|
||||
|
||||
# 计算平均值
|
||||
avg_session_duration = total_duration / total_sessions if total_sessions > 0 else 0
|
||||
avg_data_points_per_session = total_data_points / total_sessions if total_sessions > 0 else 0
|
||||
|
||||
# 添加项目级统计数据
|
||||
project_stats = {
|
||||
"project_name": project_name,
|
||||
"total_experiments": len(experiment_ids),
|
||||
"total_sessions": total_sessions,
|
||||
"total_duration": total_duration,
|
||||
"total_data_points": total_data_points,
|
||||
"total_devices": len(total_devices),
|
||||
"total_sensors": total_sensors,
|
||||
"avg_session_duration": avg_session_duration,
|
||||
"avg_data_points_per_session": avg_data_points_per_session,
|
||||
"experiment_names": experiment_names
|
||||
}
|
||||
|
||||
# 将统计数据和原始报告一起发送给分析函数
|
||||
project_data = {
|
||||
"project_stats": project_stats,
|
||||
"experiment_reports": all_experiment_reports
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_project_data(project_data)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
db_number, report_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"report", # 使用 report 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态为完成
|
||||
await redis.select(db_number) # 使用相同的 db_number
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("\n7. 分析失败")
|
||||
# 更新失败状态
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
await redis.select(db_number)
|
||||
await redis.set(status_key, json.dumps({
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def process_project_question(task_id: str, project_id: str, question: str, project: dict):
|
||||
"""处理项目问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
db_number, task_key = await get_redis_key(
|
||||
"qa_task",
|
||||
"key_pattern",
|
||||
task_id=task_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"project_id": project_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
db_number, report_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"report",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("项目分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的项目助手,负责回答关于项目报告的问题。
|
||||
你应该基于项目报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"项目名称": project.get("project_name", "未知项目"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下项目报告回答问题:\n\n项目信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
db_number, chat_history_key = await get_redis_key(
|
||||
"chat_history",
|
||||
"project",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"project_name": project.get("project_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(db_number)
|
||||
await redis.hset(
|
||||
task_key,
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
from ..cores.db import get_database
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -1,43 +1,23 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
from ..cores.db import PyObjectId, get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from ..models.basemodel import ExperimentModel, ExperimentCreate, UserModel, ExperimentStatus,FormulaCreate
|
||||
from .login import get_current_user
|
||||
from io import BytesIO, StringIO
|
||||
import zipfile
|
||||
import csv
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["experiment"]
|
||||
)
|
||||
|
||||
class ExperimentStatus:
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
|
||||
class ExperimentModel(BaseModel):
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
project_id: PyObjectId
|
||||
experiment_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
status: str = Field(default=ExperimentStatus.ACTIVE) # 添加状态字段
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentCreate(BaseModel):
|
||||
"""实验创建请求模型"""
|
||||
project_id: str
|
||||
experiment_name: str
|
||||
description: str | None = None # 修改为可选字段,默认为None
|
||||
|
||||
# 创建实验
|
||||
@router.post("/lab/experiments")
|
||||
@router.post("/experiments")
|
||||
async def create_experiment(
|
||||
experiment: ExperimentCreate,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -75,7 +55,7 @@ async def create_experiment(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create experiment: {str(e)}")
|
||||
|
||||
# 获取实验列表
|
||||
@router.get("/lab/experiments")
|
||||
@router.get("/experiments")
|
||||
async def get_experiments(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -97,38 +77,8 @@ async def get_experiments(
|
||||
experiments.append(ExperimentModel(**experiment))
|
||||
|
||||
return experiments
|
||||
# 数据模型
|
||||
class ExperimentSession(BaseModel):
|
||||
"""实会话模型,记录每次实验的开始和结束时间"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_id: PyObjectId
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
duration: Optional[float] = None # 持续时间(秒)
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
class ExperimentData(BaseModel):
|
||||
"""实验数据模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
session_ids: List[PyObjectId] # 修改为session_ids数组
|
||||
user_id: PyObjectId
|
||||
device_id: str
|
||||
sensor_name: str
|
||||
last_update: datetime # 添加最后更新时间
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# 开始实验
|
||||
@router.post("/lab/experiments/{experiment_id}/start")
|
||||
@router.post("/experiments/{experiment_id}/start")
|
||||
async def start_experiment_session(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -154,7 +104,13 @@ async def start_experiment_session(
|
||||
# 更新实验状态到Redis (新增)
|
||||
await redis.select(199)
|
||||
for device in experiment_devices:
|
||||
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"experiment_status",
|
||||
"key_pattern",
|
||||
experiment_id=experiment_id,
|
||||
device_serial=device['serial_number']
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(status_key, "active")
|
||||
|
||||
# 创建实验会话记录
|
||||
@@ -182,7 +138,7 @@ async def start_experiment_session(
|
||||
await redis.aclose()
|
||||
|
||||
# 停止实验
|
||||
@router.post("/lab/experiments/{experiment_id}/stop")
|
||||
@router.post("/experiments/{experiment_id}/stop")
|
||||
async def stop_experiment_session(
|
||||
experiment_id: str,
|
||||
session_data: dict,
|
||||
@@ -210,7 +166,13 @@ async def stop_experiment_session(
|
||||
# 更新实验状态到Redis (新增)
|
||||
await redis.select(199) # 使用同一个db
|
||||
for device in session.get("devices", []):
|
||||
status_key = f"experiment_status:{experiment_id}:{device['serial_number']}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"experiment_status",
|
||||
"key_pattern",
|
||||
experiment_id=experiment_id,
|
||||
device_serial=device['serial_number']
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.set(status_key, "inactive")
|
||||
|
||||
# 确保start_time是带时区的
|
||||
@@ -249,7 +211,7 @@ async def stop_experiment_session(
|
||||
await redis.aclose()
|
||||
|
||||
# 获取实验会话历史的路由
|
||||
@router.get("/lab/experiments/{experiment_id}/sessions")
|
||||
@router.get("/experiments/{experiment_id}/sessions")
|
||||
async def get_experiment_sessions(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -274,7 +236,7 @@ async def get_experiment_sessions(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment session records: {str(e)}")
|
||||
|
||||
# 导出实验数据
|
||||
@router.get("/lab/experiments/{experiment_id}/export")
|
||||
@router.get("/experiments/{experiment_id}/export")
|
||||
async def export_experiment_data(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -434,7 +396,7 @@ async def export_experiment_data(
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 完成实验的路由
|
||||
@router.post("/lab/experiments/{experiment_id}/complete")
|
||||
@router.post("/experiments/{experiment_id}/complete")
|
||||
async def complete_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -474,7 +436,7 @@ async def complete_experiment(
|
||||
|
||||
|
||||
# 进入实验的路由,添加状态检查
|
||||
@router.get("/lab/experiments/{experiment_id}")
|
||||
@router.get("/experiments/{experiment_id}")
|
||||
async def get_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -518,7 +480,7 @@ async def get_experiment(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment: {str(e)}")
|
||||
|
||||
# 删除实验的路由
|
||||
@router.delete("/lab/experiments/{experiment_id}")
|
||||
@router.delete("/experiments/{experiment_id}")
|
||||
async def delete_experiment(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -560,20 +522,8 @@ async def delete_experiment(
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
# 添加公式模型
|
||||
class FormulaModel(BaseModel):
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
# 添加公式请求模型
|
||||
class FormulaCreate(BaseModel):
|
||||
sensor_name: str
|
||||
data_name: str
|
||||
data_unit: str
|
||||
formula: str
|
||||
|
||||
@router.post("/lab/experiments/{experiment_id}/devices/{device_id}/formulas")
|
||||
@router.post("/experiments/{experiment_id}/devices/{device_id}/formulas")
|
||||
async def add_sensor_formula(
|
||||
experiment_id: str,
|
||||
device_id: str,
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
from .login import get_current_user, UserModel
|
||||
from ..cores.db import get_database
|
||||
from ..models.basemodel import UserModel
|
||||
from .login import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["experiment_device"]
|
||||
)
|
||||
|
||||
class ExperimentDeviceModel(BaseModel):
|
||||
"""实验-设备关联模型"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
experiment_id: PyObjectId
|
||||
user_device_id: PyObjectId
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
@router.post("/lab/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
@router.post("/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
async def add_device_to_experiment(
|
||||
experiment_id: str,
|
||||
user_device_id: str,
|
||||
@@ -77,7 +67,7 @@ async def add_device_to_experiment(
|
||||
print(f"Error adding device to experiment: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add device: {str(e)}")
|
||||
|
||||
@router.get("/lab/experiments/{experiment_id}/devices")
|
||||
@router.get("/experiments/{experiment_id}/devices")
|
||||
async def get_experiment_devices(
|
||||
experiment_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -137,7 +127,7 @@ async def get_experiment_devices(
|
||||
print(f"Error getting experiment devices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
|
||||
|
||||
@router.get("/lab/experiments/{experiment_id}/devices/public")
|
||||
@router.get("/experiments/{experiment_id}/devices/public")
|
||||
async def get_experiment_devices_public(
|
||||
experiment_id: str
|
||||
):
|
||||
@@ -188,7 +178,7 @@ async def get_experiment_devices_public(
|
||||
print(f"Error getting experiment devices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get experiment devices: {str(e)}")
|
||||
|
||||
@router.delete("/lab/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
@router.delete("/experiments/{experiment_id}/devices/{user_device_id}")
|
||||
async def remove_device_from_experiment(
|
||||
experiment_id: str,
|
||||
user_device_id: str,
|
||||
|
||||
@@ -1,245 +1,22 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
import asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
from redis import asyncio as aioredis
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from ..cores.config import (
|
||||
MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG,
|
||||
REDIS_DB_MAP, TaskStatus, THREAD_POOL_CONFIG, format_response,
|
||||
REDIS_DB_MAP, TaskStatus, format_response,
|
||||
exp_analysis_thread_pool
|
||||
)
|
||||
|
||||
from ..cores.db import get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from .login import get_current_user
|
||||
from ..models.basemodel import QuestionModel,UserModel
|
||||
from ..models.experiment import run_experiment_in_thread,process_experiment_question
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["experiment_report"]
|
||||
)
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
async def analyze_experiment_data(experiment_info):
|
||||
system_prompt = """
|
||||
You are an AI assistant tasked with analyzing experimental data.
|
||||
Generate a comprehensive experiment analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""Analyze the experiment data based on the following information:
|
||||
Experiment data: {json.dumps(experiment_info['sessions'], ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Experiment Analysis Report": {{
|
||||
"1. Basic Information": {{
|
||||
"Total Sessions": "{experiment_info['total_sessions']}",
|
||||
"Total Duration": "{experiment_info['total_duration']} seconds",
|
||||
"Data Points": "{experiment_info['total_points']}",
|
||||
"Devices number": "{experiment_info['device_stats']['total_devices']}",
|
||||
"Sensors number": "{experiment_info['device_stats']['total_sensors']}"
|
||||
}},
|
||||
"2. Session Data Analysis": {{
|
||||
"[Session ID]": {{
|
||||
"Duration": "[Duration] seconds",
|
||||
"Data Points": "[data_points]"
|
||||
}},
|
||||
// Repeat for each session dynamically
|
||||
}},
|
||||
"3. Key Findings": [
|
||||
"[Finding 1]",
|
||||
"[Finding 2]",
|
||||
"[Finding 3]"
|
||||
],
|
||||
"4. Recommendations": [
|
||||
"[Recommendation 1]",
|
||||
"[Recommendation 2]",
|
||||
"[Recommendation 3]"
|
||||
]
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
def run_experiment_in_thread(experiment_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_experiment_analysis(experiment_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
try:
|
||||
# 清理所有待处理的任务
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
# 运行直到所有任务完成
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up tasks: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_experiment_analysis(experiment_id: str):
|
||||
"""后台处理实验分析任务"""
|
||||
print(f"\n=== 开始实验分析 ===")
|
||||
print(f"实验ID: {experiment_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取实验基本信息
|
||||
experiment = await db.experiments.find_one({"_id": ObjectId(experiment_id)})
|
||||
if not experiment:
|
||||
raise ValueError("实验不存在")
|
||||
|
||||
# 收集会话数据
|
||||
sessions = []
|
||||
total_points = 0
|
||||
total_duration = 0
|
||||
devices_set = set()
|
||||
sensors_count = 0
|
||||
|
||||
async for session in db.experiment_sessions.find({
|
||||
"experiment_id": ObjectId(experiment_id)
|
||||
}):
|
||||
# 计算会话持续时间
|
||||
end_time = session.get("end_time") or datetime.now(timezone.utc)
|
||||
duration = (end_time - session["start_time"]).total_seconds()
|
||||
total_duration += duration
|
||||
|
||||
# 转换时间戳为毫秒
|
||||
start_ms = int(session["start_time"].timestamp() * 1000)
|
||||
end_ms = int(end_time.timestamp() * 1000)
|
||||
|
||||
# 统计设备和传感器数量
|
||||
session_devices = session.get("devices", [])
|
||||
for device in session_devices:
|
||||
devices_set.add(device["serial_number"])
|
||||
sensors_count += len(device.get("sensors", []))
|
||||
|
||||
# 获取会话数据点数
|
||||
session_points = 0
|
||||
for device in session_devices:
|
||||
stream_key = f"experiment:{experiment_id}:{device['serial_number']}"
|
||||
await redis.select(200)
|
||||
data_points = await redis.xrange(
|
||||
stream_key,
|
||||
min=str(start_ms),
|
||||
max=str(end_ms)
|
||||
)
|
||||
session_points += len(data_points)
|
||||
|
||||
total_points += session_points
|
||||
sessions.append({
|
||||
"session_id": str(session["_id"]),
|
||||
"duration": duration,
|
||||
"data_points": session_points
|
||||
})
|
||||
|
||||
if not sessions:
|
||||
raise ValueError("没有找到实验会话数据")
|
||||
|
||||
experiment_info = {
|
||||
"experiment_name": experiment["experiment_name"],
|
||||
"total_sessions": len(sessions),
|
||||
"total_duration": total_duration,
|
||||
"total_points": total_points,
|
||||
"device_stats": {
|
||||
"total_devices": len(devices_set),
|
||||
"total_sensors": sensors_count
|
||||
},
|
||||
"sessions": sessions
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_experiment_data(experiment_info)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态为完成
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("分析失败")
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in experiment analysis: {str(e)}")
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing connections: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/analyze")
|
||||
async def analyze_data(
|
||||
@@ -296,8 +73,12 @@ async def get_experiment_analysis_status(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
status_key = f"experiment_analysis_status:{experiment_id}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"status", # 使用 status 键模式
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
@@ -326,8 +107,12 @@ async def get_saved_report(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(REDIS_DB_MAP["experiment_analysis"])
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
db_number, report_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"report",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
@@ -343,11 +128,7 @@ async def get_saved_report(
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
"""问题模型"""
|
||||
question: str
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.post("/experiments/{experiment_id}/qa")
|
||||
async def ask_experiment_question(
|
||||
@@ -391,100 +172,6 @@ async def ask_experiment_question(
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def process_experiment_question(task_id: str, experiment_id: str, question: str, experiment: dict):
|
||||
"""处理实验问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"experiment_id": experiment_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(201)
|
||||
report_key = f"experiment_report:{experiment_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("实验分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的实验助手,负责回答关于实验报告的问题。
|
||||
你应该基于实验报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"实验名称": experiment.get("experiment_name", "未知实验"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下实验报告回答问题:\n\n实验信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"experiment_chat_history:{experiment_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"experiment_name": experiment.get("experiment_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/experiments/{experiment_id}/qa/history")
|
||||
async def get_experiment_qa_history(
|
||||
|
||||
@@ -3,10 +3,8 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidSignatureError
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Optional
|
||||
from bson import ObjectId
|
||||
from ..cores.db import PyObjectId, get_database
|
||||
from ..cores.db import get_database
|
||||
from ..models.basemodel import UserModel
|
||||
|
||||
# JWT Configuration
|
||||
SECRET_KEY = "Obscura@2024"
|
||||
@@ -19,26 +17,9 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
# Router
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["authentication"]
|
||||
tags=["login"]
|
||||
)
|
||||
|
||||
# Pydantic models
|
||||
class UserModel(BaseModel):
|
||||
"""
|
||||
用户模型
|
||||
包含用户基本信息:用户名、密码、邮箱、姓名、所属机构
|
||||
"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
username: str
|
||||
password: str
|
||||
email: EmailStr
|
||||
name: str
|
||||
institution: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# Security functions
|
||||
def create_access_token(data: dict):
|
||||
|
||||
+27
-22
@@ -1,21 +1,14 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
from ..models.basemodel import MemoModel,UserModel
|
||||
from .login import get_current_user
|
||||
import json
|
||||
|
||||
from ..cores.db import get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
from ..cores.db import get_redis, get_redis_key
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["memo"]
|
||||
)
|
||||
|
||||
class MemoModel(BaseModel):
|
||||
"""项目备忘录模型"""
|
||||
content: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@router.post("/projects/{project_id}/memo")
|
||||
async def save_project_memo(
|
||||
project_id: str,
|
||||
@@ -27,8 +20,12 @@ async def save_project_memo(
|
||||
|
||||
try:
|
||||
# 选择 db205
|
||||
await redis.select(205)
|
||||
memo_key = f"memo:{project_id}"
|
||||
db_number, memo_key = await get_redis_key(
|
||||
"project_memo",
|
||||
"key_pattern",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 保存备忘录内容和创建时间
|
||||
memo_data = {
|
||||
@@ -58,8 +55,12 @@ async def get_project_memo(
|
||||
|
||||
try:
|
||||
# 选择 db205
|
||||
await redis.select(205)
|
||||
memo_key = f"memo:{project_id}"
|
||||
db_number, memo_key = await get_redis_key(
|
||||
"project_memo",
|
||||
"key_pattern",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 获取备忘录
|
||||
memo_data = await redis.get(memo_key)
|
||||
@@ -88,11 +89,13 @@ async def save_experiment_memo(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db206
|
||||
await redis.select(206)
|
||||
memo_key = f"memo:{experiment_id}"
|
||||
db_number, memo_key = await get_redis_key(
|
||||
"experiment_memo",
|
||||
"key_pattern",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 保存备忘录内容和创建时间
|
||||
memo_data = {
|
||||
"content": memo.content,
|
||||
"create_time": memo.create_time.isoformat()
|
||||
@@ -119,11 +122,13 @@ async def get_experiment_memo(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db206
|
||||
await redis.select(206)
|
||||
memo_key = f"memo:{experiment_id}"
|
||||
db_number, memo_key = await get_redis_key(
|
||||
"experiment_memo",
|
||||
"key_pattern",
|
||||
experiment_id=experiment_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
# 获取备忘录
|
||||
memo_data = await redis.get(memo_key)
|
||||
|
||||
if not memo_data:
|
||||
|
||||
+49
-23
@@ -6,14 +6,14 @@ import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
from ..cores.db import get_database, get_redis
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from ..cores.config import UPLOAD_PATH, TaskStatus
|
||||
from ..routers.login import get_current_user, UserModel
|
||||
from ..routers.login import get_current_user
|
||||
from ..models.paper import (
|
||||
ReferenceModel, QuestionModel,
|
||||
process_batch_analysis, process_reference_question,
|
||||
run_analysis_in_thread
|
||||
)
|
||||
from ..models.basemodel import ReferenceModel, QuestionModel, UserModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
@@ -90,13 +90,16 @@ async def get_reference_report(
|
||||
reference_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis db203 读取已保存的文献报告"""
|
||||
"""从 Redis 读取已保存的文献报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{reference_id}"
|
||||
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_analysis", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"report", # 使用 key_patterns 中定义的模式
|
||||
reference_id=reference_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
existing_report = await redis.get(report_key)
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved reference report found")
|
||||
@@ -122,8 +125,12 @@ async def analyze_reference_summary_report(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"status", # 使用 key_patterns 中定义的模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
current_status = await redis.get(status_key)
|
||||
|
||||
if current_status:
|
||||
@@ -183,8 +190,12 @@ async def get_reference_analysis_status(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
status_key = f"reference_analysis_status:{project_id}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status", # 使用 status 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
@@ -209,13 +220,16 @@ async def get_reference_summary_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis db204 读取已保存的文献报告"""
|
||||
"""从 Redis 读取已保存的文献汇总报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(204)
|
||||
report_key = f"reference_summary_report:{project_id}"
|
||||
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_summary", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"report", # 使用 key_patterns 中定义的模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
existing_report = await redis.get(report_key)
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved reference report found")
|
||||
@@ -276,8 +290,13 @@ async def get_task_status(task_id: str):
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(208)
|
||||
task_data = await redis.hgetall(f"task:{task_id}")
|
||||
db_number, task_key = await get_redis_key(
|
||||
"qa_task", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"task", # 使用 key_pattern
|
||||
task_id=task_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
task_data = await redis.hgetall(task_key)
|
||||
|
||||
if not task_data:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
@@ -317,10 +336,13 @@ async def get_reference_qa_history(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(207)
|
||||
chat_history_key = f"chat_history:{reference_id}"
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
db_number, history_key = await get_redis_key(
|
||||
"chat_history", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"reference", # 使用 key_patterns 中定义的模式
|
||||
reference_id=reference_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
history_data = await redis.get(history_key)
|
||||
if not history_data:
|
||||
return []
|
||||
|
||||
@@ -387,8 +409,12 @@ async def batch_upload_references(
|
||||
|
||||
redis = await get_redis()
|
||||
try:
|
||||
await redis.select(203)
|
||||
report_key = f"reference_report:{str(result.inserted_id)}"
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_analysis", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"report", # 使用 key_patterns 中定义的模式
|
||||
reference_id=str(result.inserted_id)
|
||||
)
|
||||
await redis.select(db_number)
|
||||
initial_status = {
|
||||
"status": "processing",
|
||||
"message": "Analysis in progress"
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List
|
||||
from bson import ObjectId
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from ..routers.login import get_current_user
|
||||
from ..models.basemodel import UserModel
|
||||
from ..models.paper_summary import run_analysis_in_thread
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["paper_summary"]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/references/{project_id}/analyze_report")
|
||||
async def analyze_reference_summary_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""分析文献数据"""
|
||||
db = await get_database()
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"status", # 使用 key_patterns 中定义的模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
current_status = await redis.get(status_key)
|
||||
|
||||
if current_status:
|
||||
status_data = json.loads(current_status)
|
||||
if status_data.get("status") == "processing":
|
||||
return {
|
||||
"message": "文献分析任务正在进行中",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data.get("start_time")
|
||||
}
|
||||
|
||||
reference_cursor = db.references.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
reference_ids = []
|
||||
async for reference in reference_cursor:
|
||||
reference_ids.append(str(reference["_id"]))
|
||||
|
||||
if not reference_ids:
|
||||
raise HTTPException(status_code=404, detail="No references found for this project")
|
||||
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat(),
|
||||
"total_references": len(reference_ids),
|
||||
"completed_references": 0
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
asyncio.create_task(run_analysis_in_thread(project_id, reference_ids))
|
||||
|
||||
return {
|
||||
"message": "文献分析任务已启动",
|
||||
"status": "processing",
|
||||
"project_id": project_id,
|
||||
"start_time": status_data["start_time"],
|
||||
"total_references": len(reference_ids)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error starting reference analysis: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{project_id}/analysis_status")
|
||||
async def get_reference_analysis_status(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""获取文献分析任务的状态"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
db_number, status_key = await get_redis_key(
|
||||
"reference_summary",
|
||||
"status", # 使用 status 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
return {
|
||||
"status": "not_started",
|
||||
"project_id": project_id
|
||||
}
|
||||
|
||||
return json.loads(status_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting analysis status: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/references/{project_id}/summary_report")
|
||||
async def get_reference_summary_report(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""从 Redis 读取已保存的文献汇总报告"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
db_number, report_key = await get_redis_key(
|
||||
"reference_summary", # 使用 REDIS_DB_MAPPING 中定义的类型
|
||||
"report", # 使用 key_patterns 中定义的模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
existing_report = await redis.get(report_key)
|
||||
if not existing_report:
|
||||
raise HTTPException(status_code=404, detail="No saved reference report found")
|
||||
|
||||
return json.loads(existing_report)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting reference report: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
+18
-27
@@ -1,35 +1,16 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from bson import ObjectId
|
||||
|
||||
from ..cores.db import PyObjectId, get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from .login import get_current_user
|
||||
from ..models.basemodel import UserModel,ProjectModel
|
||||
# Router
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["projects"]
|
||||
)
|
||||
|
||||
# Pydantic models
|
||||
class ProjectModel(BaseModel):
|
||||
"""
|
||||
项目模型
|
||||
包含项目信息:用户ID、项目名称、创建时间、描述
|
||||
"""
|
||||
id: Optional[PyObjectId] = Field(alias="_id", default=None)
|
||||
user_id: Optional[PyObjectId] = None
|
||||
project_name: str
|
||||
create_time: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
description: str
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str}
|
||||
|
||||
# Project routes
|
||||
@router.post("/projects")
|
||||
async def create_project(
|
||||
@@ -79,7 +60,7 @@ async def get_projects(current_user: UserModel = Depends(get_current_user)):
|
||||
|
||||
return projects
|
||||
# 添加删除项目的路由
|
||||
@router.delete("/lab/projects/{project_id}")
|
||||
@router.delete("/projects/{project_id}")
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -108,8 +89,13 @@ async def delete_project(
|
||||
await db.experiment_devices.delete_many({"experiment_id": ObjectId(exp_id)})
|
||||
|
||||
# 删除Redis中的实验报告
|
||||
await redis.select(201)
|
||||
await redis.delete(f"experiment_report:{exp_id}")
|
||||
db_number, report_key = await get_redis_key(
|
||||
"experiment_analysis",
|
||||
"report",
|
||||
experiment_id=exp_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.delete(report_key)
|
||||
|
||||
# 删除所有实验
|
||||
await db.experiments.delete_many({"project_id": ObjectId(project_id)})
|
||||
@@ -118,8 +104,13 @@ async def delete_project(
|
||||
await db.projects.delete_one({"_id": ObjectId(project_id)})
|
||||
|
||||
# 删除Redis中的项目报告
|
||||
await redis.select(202)
|
||||
await redis.delete(f"project_report:{project_id}")
|
||||
db_number, report_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"report",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
await redis.delete(report_key)
|
||||
|
||||
return {"message": "Project successfully deleted"}
|
||||
|
||||
|
||||
@@ -1,106 +1,19 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
import asyncio
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from openai import OpenAI
|
||||
from redis import asyncio as aioredis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..cores.config import MONGODB_URL, REDIS_URL, DEEPSEEK_API_CONFIG, TaskStatus
|
||||
from ..cores.db import get_database, get_redis
|
||||
from .login import get_current_user, UserModel
|
||||
from ..cores.config import TaskStatus
|
||||
from ..cores.db import get_database, get_redis, get_redis_key
|
||||
from .login import get_current_user
|
||||
from ..models.project import pro_analysis_executor, run_project_in_thread, QuestionModel, process_project_question
|
||||
from ..models.basemodel import UserModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["project_report"]
|
||||
)
|
||||
|
||||
# DeepSeek API Configuration
|
||||
client = OpenAI(
|
||||
base_url=DEEPSEEK_API_CONFIG["base_url"],
|
||||
api_key=DEEPSEEK_API_CONFIG["api_key"]
|
||||
)
|
||||
|
||||
async def analyze_project_data(project_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
system_prompt = """
|
||||
You are an AI assistant responsible for analyzing lab reports.
|
||||
You will summarize and analyze all lab reports and generate a comprehensive analysis report in JSON format.
|
||||
The JSON structure must strictly follow the provided template.
|
||||
"""
|
||||
|
||||
user_prompt = f"""The project report is summarized based on the following experimental analysis reports:
|
||||
Experimental analysis reports: {json.dumps(project_data, ensure_ascii=False)}
|
||||
|
||||
Generate a JSON response with the following structure:
|
||||
|
||||
{{
|
||||
"Project Analysis Report": {{
|
||||
"1. Project Overview": {{
|
||||
"Project Name": "{project_data['project_stats']['project_name']}",
|
||||
"Total Experiments": "{project_data['project_stats']['total_experiments']}",
|
||||
"Total Data Points": "{project_data['project_stats']['total_data_points']}",
|
||||
}},
|
||||
"2. Aggregated Statistics": {{
|
||||
"Total Sessions": {project_data['project_stats']['total_sessions']},
|
||||
"Total Duration": {project_data['project_stats']['total_duration']},
|
||||
"Total Data Points": {project_data['project_stats']['total_data_points']},
|
||||
"Average Session Duration": {project_data['project_stats']['avg_session_duration']},
|
||||
"Average Data Points per Session": {project_data['project_stats']['avg_data_points_per_session']}
|
||||
}},
|
||||
"3. Performance Analysis": {{
|
||||
"Best Performing Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Duration": "duration",
|
||||
"Data Points": "data_points",
|
||||
"Success Factors": "success_factors"
|
||||
}}
|
||||
],
|
||||
"Problematic Sessions": [
|
||||
{{
|
||||
"Session ID": "session_id",
|
||||
"experiment_id": "experiment_id",
|
||||
"experiment_name": "experiment_name",
|
||||
"Issues": "specific_issues",
|
||||
"Possible Causes": "possible_causes"
|
||||
}}
|
||||
]
|
||||
}},
|
||||
"4. Common Findings": {{
|
||||
"Recurring Issues": "common_failure_modes",
|
||||
"Equipment performance": "equipment_performance",
|
||||
"Sensor reliability analysis": "sensor_reliability_analysis"
|
||||
}},
|
||||
"5. Recommendations": {{
|
||||
"Equipment optimization suggestions": "equipment_optimization_suggestions",
|
||||
"Experiment process improvement suggestions": "experiment_process_improvement_suggestions",
|
||||
"Data collection strategy adjustment": "data_collection_strategy_adjustment"
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages,
|
||||
response_format={'type': 'json_object'}
|
||||
)
|
||||
return json.loads(response.choices[0].message.content)
|
||||
except Exception as e:
|
||||
print(f"Error calling DeepSeek API: {e}")
|
||||
return None
|
||||
|
||||
@router.get("/projects/{project_id}/analyze")
|
||||
async def analyze_project_data_endpoint(
|
||||
project_id: str,
|
||||
@@ -139,16 +52,14 @@ async def analyze_project_data_endpoint(
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
print("启动分析线程...")
|
||||
# 创建并启动新线程
|
||||
import threading
|
||||
analysis_thread = threading.Thread(
|
||||
target=run_project_in_thread,
|
||||
args=(project_id,),
|
||||
daemon=True
|
||||
print("使用项目分析线程池启动任务...")
|
||||
# 使用线程池执行任务
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
pro_analysis_executor,
|
||||
lambda: run_project_in_thread(project_id)
|
||||
)
|
||||
analysis_thread.start()
|
||||
print("分析线程已启动")
|
||||
print("分析任务已提交到线程池")
|
||||
|
||||
return {
|
||||
"message": "项目分析任务已启动",
|
||||
@@ -166,164 +77,6 @@ async def analyze_project_data_endpoint(
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
def run_project_in_thread(project_id: str):
|
||||
"""在独立线程中运行分析任务"""
|
||||
# 创建新的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步任务
|
||||
loop.run_until_complete(process_project_analysis(project_id))
|
||||
except Exception as e:
|
||||
print(f"Error in analysis thread: {str(e)}")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
async def process_project_analysis(project_id: str):
|
||||
"""后台处理项目分析任务"""
|
||||
print(f"\n=== 开始项目分析 ===")
|
||||
print(f"项目ID: {project_id}")
|
||||
|
||||
# 创建新的数据库和Redis连接
|
||||
mongo_client = AsyncIOMotorClient(MONGODB_URL)
|
||||
db = mongo_client["lab"]
|
||||
redis = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
|
||||
# 更新状态为进行中
|
||||
status_data = {
|
||||
"status": "processing",
|
||||
"start_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
# 获取项目名称
|
||||
project = await db.projects.find_one({"_id": ObjectId(project_id)})
|
||||
project_name = project["project_name"] if project else "Unknown Project"
|
||||
|
||||
# 从MongoDB查找项目下的所有实验
|
||||
experiment_cursor = db.experiments.find({
|
||||
"project_id": ObjectId(project_id)
|
||||
})
|
||||
|
||||
experiment_ids = []
|
||||
experiment_names = {} # 存储实验ID和名称的映射
|
||||
async for experiment in experiment_cursor:
|
||||
exp_id = str(experiment["_id"])
|
||||
experiment_ids.append(exp_id)
|
||||
experiment_names[exp_id] = experiment["experiment_name"]
|
||||
|
||||
if not experiment_ids:
|
||||
print("错误: 未找到任何实验")
|
||||
raise Exception("No experiments found in this project")
|
||||
|
||||
# 从Redis db201获取所有实验报告
|
||||
await redis.select(201)
|
||||
all_experiment_reports = {}
|
||||
|
||||
# 统计数据初始化
|
||||
total_sessions = 0
|
||||
total_duration = 0
|
||||
total_data_points = 0
|
||||
total_devices = set()
|
||||
total_sensors = 0
|
||||
|
||||
for exp_id in experiment_ids:
|
||||
report_key = f"experiment_report:{exp_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if report_data:
|
||||
try:
|
||||
parsed_report = json.loads(report_data)
|
||||
all_experiment_reports[exp_id] = parsed_report
|
||||
|
||||
# 从Basic Information中提取数据
|
||||
basic_info = parsed_report["Experiment Analysis Report"]["1. Basic Information"]
|
||||
total_sessions += int(basic_info["Total Sessions"])
|
||||
total_duration += float(basic_info["Total Duration"].split()[0]) # 去掉"seconds"
|
||||
total_data_points += int(basic_info["Data Points"])
|
||||
total_devices.update([str(i) for i in range(int(basic_info["Devices number"]))])
|
||||
total_sensors += int(basic_info["Sensors number"])
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
continue
|
||||
|
||||
if not all_experiment_reports:
|
||||
raise Exception("No experiment reports found")
|
||||
|
||||
# 计算平均值
|
||||
avg_session_duration = total_duration / total_sessions if total_sessions > 0 else 0
|
||||
avg_data_points_per_session = total_data_points / total_sessions if total_sessions > 0 else 0
|
||||
|
||||
# 添加项目级统计数据
|
||||
project_stats = {
|
||||
"project_name": project_name,
|
||||
"total_experiments": len(experiment_ids),
|
||||
"total_sessions": total_sessions,
|
||||
"total_duration": total_duration,
|
||||
"total_data_points": total_data_points,
|
||||
"total_devices": len(total_devices),
|
||||
"total_sensors": total_sensors,
|
||||
"avg_session_duration": avg_session_duration,
|
||||
"avg_data_points_per_session": avg_data_points_per_session,
|
||||
"experiment_names": experiment_names
|
||||
}
|
||||
|
||||
# 将统计数据和原始报告一起发送给分析函数
|
||||
project_data = {
|
||||
"project_stats": project_stats,
|
||||
"experiment_reports": all_experiment_reports
|
||||
}
|
||||
|
||||
# 执行分析
|
||||
analysis_result = await analyze_project_data(project_data)
|
||||
|
||||
if analysis_result:
|
||||
# 保存分析结果
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
await redis.set(report_key, json.dumps(analysis_result))
|
||||
|
||||
# 更新状态为完成
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
status_data = {
|
||||
"status": "completed",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
print("分析报告已保存")
|
||||
else:
|
||||
print("\n7. 分析失败")
|
||||
# 更新失败状态
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": "Failed to generate analysis result",
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
status_data = {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"completion_time": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
await redis.set(status_key, json.dumps(status_data))
|
||||
except Exception as redis_error:
|
||||
print(f"Error updating Redis status: {redis_error}")
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
mongo_client.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/projects/{project_id}/analysis_status")
|
||||
async def get_project_analysis_status(
|
||||
@@ -334,8 +87,12 @@ async def get_project_analysis_status(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
await redis.select(202)
|
||||
status_key = f"project_analysis_status:{project_id}"
|
||||
db_number, status_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"status", # 使用 status 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
status_data = await redis.get(status_key)
|
||||
|
||||
if not status_data:
|
||||
@@ -364,11 +121,12 @@ async def get_project_saved_report(
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 选择 db202
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
|
||||
# 获取已保存的报告
|
||||
db_number, report_key = await get_redis_key(
|
||||
"project_analysis",
|
||||
"report", # 使用 report 键模式
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
existing_report = await redis.get(report_key)
|
||||
|
||||
if not existing_report:
|
||||
@@ -386,13 +144,8 @@ async def get_project_saved_report(
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
class QuestionModel(BaseModel):
|
||||
"""问题模型"""
|
||||
question: str
|
||||
|
||||
# 项目问答
|
||||
|
||||
@router.post("/lab/projects/{project_id}/qa")
|
||||
@router.post("/projects/{project_id}/qa")
|
||||
async def ask_project_question(
|
||||
project_id: str,
|
||||
question: QuestionModel,
|
||||
@@ -434,102 +187,7 @@ async def ask_project_question(
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
async def process_project_question(task_id: str, project_id: str, question: str, project: dict):
|
||||
"""处理项目问答的后台任务"""
|
||||
redis = await get_redis()
|
||||
|
||||
try:
|
||||
# 更新任务状态为处理中
|
||||
await redis.select(208) # 使用db208存储任务状态
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.PROCESSING,
|
||||
"project_id": project_id,
|
||||
"question": question
|
||||
}
|
||||
)
|
||||
|
||||
# 从Redis获取分析报告
|
||||
await redis.select(202)
|
||||
report_key = f"project_report:{project_id}"
|
||||
report_data = await redis.get(report_key)
|
||||
|
||||
if not report_data:
|
||||
raise Exception("项目分析报告不存在,请先进行分析")
|
||||
|
||||
# 构建系统提示和用户提示
|
||||
system_prompt = """
|
||||
你是一个专业的项目助手,负责回答关于项目报告的问题。
|
||||
你应该基于项目报告提供准确、专业的回答。
|
||||
回答应当简洁明了,并尽可能引用报告中的具体内容。
|
||||
"""
|
||||
|
||||
# 构建上下文
|
||||
context = {
|
||||
"项目名称": project.get("project_name", "未知项目"),
|
||||
"分析报告": json.loads(report_data)
|
||||
}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"基于以下项目报告回答问题:\n\n项目信息:{json.dumps(context, ensure_ascii=False)}\n\n问题:{question}"}
|
||||
]
|
||||
|
||||
# 调用DeepSeek API获取回答
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content
|
||||
|
||||
# 保存对话历史到Redis db207
|
||||
await redis.select(207)
|
||||
chat_history_key = f"project_chat_history:{project_id}"
|
||||
|
||||
# 获取现有历史记录
|
||||
existing_history = await redis.get(chat_history_key)
|
||||
history = json.loads(existing_history) if existing_history else []
|
||||
|
||||
# 添加新的对话
|
||||
history.append({
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 保存更新后的历史记录
|
||||
await redis.set(chat_history_key, json.dumps(history))
|
||||
|
||||
# 更新任务状态为完成
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.COMPLETED,
|
||||
"answer": answer,
|
||||
"project_name": project.get("project_name")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing question: {e}")
|
||||
# 更新任务状态为失败
|
||||
await redis.select(208)
|
||||
await redis.hset(
|
||||
f"task:{task_id}",
|
||||
mapping={
|
||||
"status": TaskStatus.FAILED,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
await redis.aclose()
|
||||
except Exception as e:
|
||||
print(f"Error closing Redis connection: {e}")
|
||||
|
||||
@router.get("/lab/projects/{project_id}/qa/history")
|
||||
@router.get("/projects/{project_id}/qa/history")
|
||||
async def get_project_qa_history(
|
||||
project_id: str,
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
@@ -539,8 +197,12 @@ async def get_project_qa_history(
|
||||
|
||||
try:
|
||||
# 从Redis db207获取对话历史
|
||||
await redis.select(207)
|
||||
chat_history_key = f"project_chat_history:{project_id}"
|
||||
db_number, chat_history_key = await get_redis_key(
|
||||
"chat_history",
|
||||
"project",
|
||||
project_id=project_id
|
||||
)
|
||||
await redis.select(db_number)
|
||||
|
||||
history_data = await redis.get(chat_history_key)
|
||||
if not history_data:
|
||||
|
||||
@@ -4,7 +4,10 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
router = APIRouter()
|
||||
router = APIRouter(
|
||||
prefix="/lab",
|
||||
tags=["websocket"]
|
||||
)
|
||||
|
||||
# Connection Manager for WebSocket clients
|
||||
class ConnectionManager:
|
||||
@@ -90,7 +93,7 @@ class ConnectionManager:
|
||||
# 创建连接管理器实例
|
||||
manager = ConnectionManager()
|
||||
|
||||
@router.websocket("/lab/ws/{experiment_id}/{serial_number}")
|
||||
@router.websocket("/ws/{experiment_id}/{serial_number}")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
experiment_id: str,
|
||||
@@ -133,7 +136,7 @@ async def websocket_endpoint(
|
||||
finally:
|
||||
await manager.disconnect(websocket, experiment_id, serial_number)
|
||||
|
||||
@router.get("/lab/status")
|
||||
@router.get("/status")
|
||||
async def get_status():
|
||||
return {
|
||||
"status": "running",
|
||||
|
||||
Reference in New Issue
Block a user