293 lines
10 KiB
Python
293 lines
10 KiB
Python
from fastapi import FastAPI, HTTPException, File, UploadFile
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from ultralytics import YOLO
|
|
import cv2
|
|
import numpy as np
|
|
import json
|
|
import uvicorn
|
|
from kafka import KafkaProducer, KafkaConsumer
|
|
from redis import Redis
|
|
import io
|
|
import uuid
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
import threading
|
|
|
|
# Root ASGI application. The fall-detection routes are declared on a separate
# sub-application that is mounted under the "/fall" path prefix.
app, fall_app = FastAPI(), FastAPI()
app.mount("/fall", fall_app)
|
|
|
|
# CORS configuration: the browser origins allowed to call this API.
# This must be a sequence, not a bare string. Starlette checks
# `origin in allow_origins`; against a str that is a substring test, not an
# exact-origin match.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

# The CORS middleware is registered on the main app only. Requests for the
# mounted /fall sub-app are dispatched by the main app, so they pass through
# this middleware as well.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Configuration.
# Each value can be overridden by an environment variable of the same name.
# The literals below are the previous hard-coded values, kept as defaults so
# existing deployments behave exactly as before.
# NOTE(review): the Redis password and the broker/host addresses are committed
# to source control here. Provide them via the environment and drop the
# literal defaults.
MODEL_PATH = os.environ.get("MODEL_PATH", "/home/zydi/models/yolov8n-fall.pt")  # YOLO weights file

KAFKA_BROKER = os.environ.get("KAFKA_BROKER", "222.186.10.253:9092")
KAFKA_TOPIC = os.environ.get("KAFKA_TOPIC", "fall")  # topic that carries processing tasks
KAFKA_GROUP_ID = os.environ.get("KAFKA_GROUP_ID", "fall_group")  # consumer group of the worker

REDIS_HOST = os.environ.get("REDIS_HOST", "222.186.10.253")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "Obscura@2024")
REDIS_DB = int(os.environ.get("REDIS_DB", "4"))  # logical DB index used by the Redis client

UPLOAD_DIR = os.environ.get("UPLOAD_DIR", "/www/wwwroot/beta.obscura.work/upload_files/upload")
RESULT_DIR = os.environ.get("RESULT_DIR", "/www/wwwroot/beta.obscura.work/upload_files/result")

# NOTE(review): not referenced anywhere else in this file. Presumably intended
# for cleaning up old uploads/results; confirm or remove.
MAX_FILE_AGE = timedelta(hours=1)
|
|
|
|
# Make sure the upload and result directories exist (no-op when already present).
for _directory in (UPLOAD_DIR, RESULT_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory
|
|
|
|
# Kafka clients:
#   producer - used by the upload endpoint to publish processing tasks.
#   consumer - iterated by the background worker loop (process_task).
def _json_from_bytes(raw):
    """Decode a UTF-8 Kafka message value and parse it as JSON."""
    return json.loads(raw.decode('utf-8'))


producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])

consumer = KafkaConsumer(
    KAFKA_TOPIC,
    group_id=KAFKA_GROUP_ID,
    bootstrap_servers=[KAFKA_BROKER],
    enable_auto_commit=True,
    auto_offset_reset='earliest',
    value_deserializer=_json_from_bytes,
)
|
|
|
|
# Shared Redis client. The endpoints and the worker store per-file status
# records under keys of the form "fall_result:<filename>".
redis_client = Redis(
    db=REDIS_DB,
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
)
|
|
|
|
class fallDetector:
    """Thin wrapper around an ultralytics YOLO model used for fall detection.

    The lowercase class name is kept for compatibility with existing callers.
    """

    def __init__(self, model_path):
        """Create the YOLO model from the weights file at ``model_path``."""
        self.model = YOLO(model_path)

    def detect(self, frame):
        """Run the model on ``frame`` and return its raw results object(s)."""
        return self.model(frame)

    def format_results(self, results):
        """Convert raw model results into JSON-serialisable detection dicts.

        Args:
            results: iterable of result objects as returned by :meth:`detect`.
                Each is expected to expose ``boxes`` (supporting ``len`` and
                ``[i]``; each box has ``cls``, ``conf`` and ``xyxy``) and may
                expose ``names`` (class id -> label mapping).

        Returns:
            A list with one ``{"bbox", "confidence", "class_id", "class"}``
            dict per valid box, across all results.

            If no detection could be collected and at least one result had no
            boxes, returns ``[{"message": "No objects detected"}]`` instead.

            Boxes that lack a required attribute, or fail to convert, are
            skipped with a printed warning.
        """
        formatted_results = []
        saw_empty_result = False

        for r in results:
            # `boxes` may be missing or None, so guard before calling len().
            boxes = getattr(r, 'boxes', None)
            if boxes is None or len(boxes) == 0:
                print("没有检测到任何对象")
                saw_empty_result = True
                # Keep going: other results in the batch may still contain boxes.
                continue

            names = getattr(r, 'names', None) or {}

            # Index-based access: the boxes container is only relied on for
            # len() and [i].
            for i in range(len(boxes)):
                box = boxes[i]
                if not (hasattr(box, 'cls') and hasattr(box, 'conf') and hasattr(box, 'xyxy')):
                    print(f"警告: 第 {i} 个框缺少必要的属性")
                    continue

                try:
                    class_id = int(box.cls.item())
                    formatted_results.append({
                        "bbox": box.xyxy.tolist()[0],
                        "confidence": box.conf.item(),
                        "class_id": class_id,
                        "class": names.get(class_id, f"Unknown-{class_id}"),
                    })
                except Exception as e:
                    print(f"处理第 {i} 个框时出错: {str(e)}")

        if not formatted_results and saw_empty_result:
            return [{"message": "No objects detected"}]
        return formatted_results

    def draw_results(self, frame, results):
        """Return the image produced by ``plot()`` of the last result.

        Callers in this file pass results for a single frame, so there is
        presumably exactly one result. If ``results`` is empty, ``frame`` is
        returned unchanged instead of raising ``UnboundLocalError``.
        """
        annotated_frame = frame
        for r in results:
            annotated_frame = r.plot()
        return annotated_frame
|
|
|
|
# Single module-level detector shared by process_image and process_video.
# It is constructed once at import time, which builds YOLO(MODEL_PATH).
detector = fallDetector(model_path=MODEL_PATH)
|
|
|
|
def process_image(image_data, filename):
    """Run fall detection on an encoded image and save an annotated copy.

    Args:
        image_data: raw encoded image bytes (anything ``cv2.imdecode`` accepts).
        filename: stored upload name. The annotated image is written to
            ``RESULT_DIR/fall_<filename>``.

    Returns:
        ``(json_results, annotated_filename)`` on success, or ``(None, None)``
        on any failure: undecodable input, inference error, or failed write.
    """
    try:
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # imdecode reports undecodable data by returning None rather than raising.
        if img is None:
            print(f"Error processing image: could not decode {filename}")
            return None, None

        results = detector.detect(img)

        # Format results for JSON, then draw them onto the image.
        json_results = detector.format_results(results)
        annotated_img = detector.draw_results(img, results)

        annotated_filename = f"fall_{filename}"
        annotated_path = os.path.join(RESULT_DIR, annotated_filename)
        # imwrite may report failure (e.g. unwritable directory) by returning
        # False. Without this check the task would be marked completed even
        # though no file exists.
        if not cv2.imwrite(annotated_path, annotated_img):
            print(f"Error processing image: could not write {annotated_path}")
            return None, None

        return json_results, annotated_filename
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None, None
|
|
|
|
|
|
|
|
def process_video(video_data, filename):
    """Run fall detection on every frame of a video and write an annotated video.

    The bytes are first written to a temporary file ``UPLOAD_DIR/fall_<filename>``
    so OpenCV can open them. That file is always removed afterwards.

    Args:
        video_data: raw bytes of the uploaded video file.
        filename: stored upload name. The output is written to
            ``RESULT_DIR/fall_<filename>``.

    Returns:
        ``(json_results, annotated_filename)`` on success, where ``json_results``
        is a list of ``{"frame": index, "detections": [...]}``. Returns
        ``(None, None)`` on any failure.
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
    annotated_filename = f"fall_{filename}"
    output_path = os.path.join(RESULT_DIR, annotated_filename)
    cap = None
    out = None

    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            print(f"Error processing video: could not open {temp_video_path}")
            return None, None

        # Keep FPS as a float: int() truncated e.g. 29.97 to 29, which made the
        # output drift out of sync with the source.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            # The container did not report a usable FPS (0 or NaN).
            # NOTE(review): 25.0 is an assumed fallback; confirm it suits your sources.
            fps = 25.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # NOTE(review): 'mp4v' is generally not playable in browsers; consider 'avc1'
        # if the annotated video is streamed to a web client.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not out.isOpened():
            # Otherwise every write() would silently do nothing and an empty
            # result would be reported as success.
            print(f"Error processing video: could not create {output_path}")
            return None, None

        json_results = []
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = detector.detect(frame)
            json_results.append({"frame": frame_count, "detections": detector.format_results(results)})
            out.write(detector.draw_results(frame, results))
            frame_count += 1

        return json_results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Runs on every path (including early returns and exceptions). This
        # finalises the output file before the caller uses it and never leaks
        # the capture, the writer, or the temporary input file.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        try:
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)
        except OSError as e:
            print(f"Error removing temporary video {temp_video_path}: {str(e)}")
|
|
|
|
@fall_app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded image/video and queue it for fall detection.

    The file is saved as ``UPLOAD_DIR/<uuid4><ext>``, where the extension comes
    from the client-supplied name, lower-cased. Its status is set to ``queued``
    in Redis, and a task is then published to Kafka. ``.jpg``/``.jpeg``/``.png``
    are processed as images; every other extension is processed as video.

    Returns:
        JSON with a confirmation message and the generated ``filename``. That
        filename is the key used by ``/result/{filename}`` and
        ``/annotated/{filename}``.
    """
    # NOTE(review): the file write, Redis and Kafka calls below are blocking
    # inside an async handler, so they stall the event loop while they run.
    content = await file.read()
    # UploadFile.filename may be None. Fall back to "" so splitext yields no
    # extension instead of raising TypeError.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"

    # Save the original file where the worker will read it.
    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    # Record "queued" BEFORE publishing the task. The worker may consume it
    # immediately and write "processing"/"completed". Setting "queued" after
    # the send could overwrite that status and leave the file looking stuck.
    redis_key = f"fall_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
    }).encode('utf-8'))

    return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
|
|
|
@fall_app.get("/result/{filename}")
async def get_fall_result(filename: str):
    """Return the full stored record (including ``status``) for ``filename``.

    Raises:
        HTTPException: 404 when no record exists in Redis.
    """
    stored = redis_client.get(f"fall_result:{filename}")
    if not stored:
        raise HTTPException(status_code=404, detail="Result not found")
    # Return the whole record as-is, status included.
    return JSONResponse(content=json.loads(stored))
|
|
|
|
@fall_app.get("/annotated/{filename}")
async def get_annotated_file(filename: str):
    """Stream the annotated output for ``filename`` once processing has completed.

    Raises:
        HTTPException: 404 when there is no Redis record, the status is not
            ``completed``, or the annotated file is missing from ``RESULT_DIR``.
    """
    result = redis_client.get(f"fall_result:{filename}")
    if result:
        result_data = json.loads(result)
        if result_data.get("status") == "completed":
            annotated_filename = result_data["annotated_filename"]
            file_path = os.path.join(RESULT_DIR, annotated_filename)
            if os.path.exists(file_path):
                file_extension = os.path.splitext(annotated_filename)[1].lower()
                # "image/jpg" is not a registered MIME type; both .jpg and .jpeg
                # are image/jpeg.
                # NOTE(review): non-mp4 video containers (e.g. .avi) are still
                # labelled video/mp4.
                image_media_types = {'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png'}
                media_type = image_media_types.get(file_extension, "video/mp4")

                def iterfile():
                    # Read fixed-size chunks. Iterating a binary file object
                    # splits on b"\n", giving chunks of arbitrary size.
                    with open(file_path, mode="rb") as file_like:
                        for chunk in iter(lambda: file_like.read(64 * 1024), b""):
                            yield chunk

                return StreamingResponse(iterfile(), media_type=media_type)

    raise HTTPException(status_code=404, detail="Annotated file not found")
|
|
|
|
def process_task():
    """Worker loop: consume tasks from Kafka forever and record outcomes in Redis.

    Each message value is expected to be a dict with ``filename`` and
    ``file_type``, as published by ``upload_file``. A failure while handling one
    task is logged and does not stop the loop.

    NOTE(review): a message that is not valid JSON presumably makes the
    consumer's deserializer raise during iteration, which would still end this
    loop.
    """
    for message in consumer:
        try:
            _process_single_task(message.value)
        except Exception as e:
            # A malformed task or a Redis error must not kill the daemon worker
            # thread. Log it and keep consuming.
            print(f"Error processing task: {str(e)}")


def _process_single_task(task):
    """Process one task dict and write processing/completed/failed status to Redis."""
    filename = task['filename']
    file_type = task['file_type']
    file_path = os.path.join(UPLOAD_DIR, filename)

    redis_key = f"fall_result:{filename}"
    redis_client.set(redis_key, json.dumps({"status": "processing"}))

    try:
        with open(file_path, 'rb') as f:
            content = f.read()

        # Anything not explicitly an image is treated as video, matching upload_file.
        handler = process_image if file_type == "image" else process_video
        json_results, annotated_filename = handler(content, filename)

        if json_results and annotated_filename:
            redis_client.set(redis_key, json.dumps({
                "json_results": json_results,
                "status": "completed",
                "annotated_filename": annotated_filename
            }))
        else:
            redis_client.set(redis_key, json.dumps({"status": "failed"}))
    except Exception as e:
        print(f"Error processing task: {str(e)}")
        redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
|
|
|
def listen_redis_changes():
    """Print status changes of ``fall_result:*`` keys via Redis keyspace notifications.

    Blocks forever on the pub/sub connection, so it should run in its own
    (daemon) thread.

    NOTE(review): keyspace notifications are off by default on a Redis server.
    ``notify-keyspace-events`` must enable keyspace events for string commands
    (e.g. "K$"), otherwise nothing is received.
    """
    pubsub = redis_client.pubsub()
    # The channel name embeds the database index. It must match the DB the
    # client writes to (REDIS_DB); a hard-coded index silently receives nothing.
    prefix = f"__keyspace@{REDIS_DB}__:fall_result:"
    pubsub.psubscribe(prefix + '*')  # every fall_result key in this DB

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        channel = message['channel'].decode('utf-8')
        # Strip the known prefix rather than splitting on ':', so keys that
        # contain ':' are kept intact.
        key = channel[len(prefix):]
        operation = message['data'].decode('utf-8')
        if operation != 'set':
            continue

        value = redis_client.get(f"fall_result:{key}")
        if not value:
            continue
        try:
            result = json.loads(value)
        except ValueError as e:
            print(f"Invalid status payload for {key}: {e}")
            continue

        print(f"Status update for {key}: {result.get('status')}")
        # Further handling (e.g. sending notifications) can be added here.
|
|
|
|
if __name__ == "__main__":
    # Start the background workers: first the Kafka task consumer, then the
    # Redis keyspace listener. Both are daemon threads, so they do not keep the
    # process alive after the server stops.
    # NOTE(review): they only start when this file is run directly. Launching
    # via `uvicorn module:app` skips them.
    for worker in (process_task, listen_redis_changes):
        threading.Thread(target=worker, daemon=True).start()

    uvicorn.run(app, host="0.0.0.0", port=7002)