290 lines
10 KiB
Python
290 lines
10 KiB
Python
import os
|
|
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from redis import Redis
|
|
from ultralytics import YOLO
|
|
import json
|
|
from kafka import KafkaConsumer
|
|
import threading
|
|
import redis
|
|
import torch
|
|
from config import *
|
|
# Bind this process's default CUDA device to GPU 1; the model and inference
# calls in this module also address 'cuda:1' explicitly.
torch.cuda.set_device(1)

# Configuration.
# KAFKA_BROKER, REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB (the
# main Redis DB), UPLOAD_DIR and RESULT_DIR are used exactly as provided by
# `from config import *`, so they are not re-bound here.
_FACE_WORKER_CONFIG = WORKER_CONFIGS["face"]

MODEL_PATH = FACE_MODEL_PATH
KAFKA_TOPIC = _FACE_WORKER_CONFIG["kafka_topic"]
KAFKA_GROUP_ID = f"face_{KAFKA_GROUP_ID_PREFIX}"
REDIS_DB = _FACE_WORKER_CONFIG["redis_db"]  # Redis DB used by this worker
|
|
|
|
# Initialise Kafka.
def _deserialize_task(raw):
    """Decode a Kafka message value as UTF-8 JSON.

    Returns ``None`` (after logging) for a null value or a payload that is not
    valid UTF-8 JSON, instead of raising from inside the consumer iterator.
    Code reading ``message.value`` must therefore handle ``None``.
    """
    if raw is None:
        return None
    try:
        return json.loads(raw.decode('utf-8'))
    except ValueError as e:  # covers UnicodeDecodeError and json.JSONDecodeError
        print(f"无法解析Kafka消息: {str(e)}")
        return None


# NOTE(review): with enable_auto_commit=True, offsets are committed
# periodically, independent of whether a message was fully processed, so a
# crash mid-task can lose that task. Consider manual commits if that matters.
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=_deserialize_task,
)
|
|
|
|
# Initialise Redis: one client per database, all on the same server.
def _connect_redis(db):
    """Return a Redis client for database number *db* on the configured server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


# Worker DB: this worker writes its face_result:<task_id> hashes here.
redis_client = _connect_redis(REDIS_DB)
# Main DB: this worker updates the shared task:<task_id> hashes here.
main_redis_client = _connect_redis(MAIN_REDIS_DB)
|
|
|
|
class faceDetector:
    """Face detector built on an ultralytics YOLO model with keypoints.

    Runs inference, converts the raw results into JSON-serialisable dicts in
    the coordinate space of the original image, and draws them onto frames.
    """

    def __init__(self, model_path, device='cuda:1'):
        """Load YOLO weights from *model_path* onto *device*.

        Args:
            model_path: Path to the YOLO weights file.
            device: Torch device used for the model and for inference.
                Defaults to ``'cuda:1'``, the previously hard-coded device.
        """
        self.device = device
        self.model = YOLO(model_path).to(device)

    def detect(self, frame):
        """Run the model on *frame* and return the raw ultralytics results."""
        return self.model(frame, device=self.device)

    def format_results(self, results, original_shape):
        """Convert raw results into plain dicts in original-image coordinates.

        Args:
            results: Iterable of ultralytics ``Results`` objects.
            original_shape: Shape of the source image; only the first two
                entries ``(height, width)`` are used.

        Returns:
            A list of dicts with keys ``"bbox"`` (``[x1, y1, x2, y2]``),
            ``"confidence"`` and ``"keypoints"`` (list of ``[x, y]``; empty
            when the model has no keypoint head). All numbers are built-in
            Python floats, so the list can go straight into ``json.dumps``.
        """
        orig_h, orig_w = original_shape[:2]
        formatted_results = []
        for r in results:
            boxes = r.boxes
            if boxes is None:
                continue

            # r.orig_shape is the (h, w) of the input the model actually saw;
            # scale each axis back to the original image size.
            model_h, model_w = r.orig_shape[:2]
            scale_x, scale_y = orig_w / model_w, orig_h / model_h

            # .tolist() yields Python floats; raw np.float32 scalars would make
            # json.dumps raise TypeError.
            bboxes = (boxes.xyxy.cpu().numpy()
                      * np.array([scale_x, scale_y, scale_x, scale_y])).tolist()
            confidences = boxes.conf.cpu().numpy().tolist()

            if r.keypoints is not None:
                keypoints = (r.keypoints.xy.cpu().numpy()
                             * np.array([scale_x, scale_y])).tolist()
            else:
                keypoints = [[] for _ in bboxes]

            for bbox, confidence, kpts in zip(bboxes, confidences, keypoints):
                formatted_results.append({
                    "bbox": bbox,
                    "confidence": confidence,
                    "keypoints": kpts,
                })
        return formatted_results

    def draw_results(self, frame, results):
        """Return a copy of *frame* with the formatted detections drawn on it.

        Boxes are drawn with colour (0, 255, 0) and thickness 2; keypoints as
        filled circles of radius 5 with colour (255, 0, 0). On a BGR image
        these are green and blue. *frame* itself is not modified.
        """
        annotated_frame = frame.copy()
        for det in results:
            x1, y1, x2, y2 = (int(v) for v in det["bbox"])
            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            for kp in det["keypoints"]:
                cv2.circle(annotated_frame, (int(kp[0]), int(kp[1])),
                           5, (255, 0, 0), -1)
        return annotated_frame
|
|
|
|
|
|
# Module-level detector instance shared by process_image and process_video.
detector = faceDetector(model_path=MODEL_PATH)
|
|
|
|
def process_image(image_path):
    """Detect faces in a single image file.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.

    Returns:
        ``(json_results, annotated_img)``:
        - ``json_results`` is the list from ``detector.format_results``. It may
          be empty when no face is found.
        - ``annotated_img`` is a copy of the image with the detections drawn.

        Returns ``(None, None)`` if the image cannot be read or any step fails.
    """
    try:
        original_img = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if original_img is None:
            print(f"处理图像时出错: 无法读取图像 {image_path}")
            return None, None

        # Same preprocessing as video frames (RGB, 640x640, CHW, /255, batch dim).
        img = preprocess_frame(original_img)
        results = detector.detect(img)
        json_results = detector.format_results(results, original_img.shape)
        annotated_img = detector.draw_results(original_img, json_results)
        return json_results, annotated_img
    except Exception as e:
        print(f"处理图像时出错: {str(e)}")
        return None, None
|
|
|
|
def process_video(video_path, output_path=None):
    """Run face detection on one frame per second of a video.

    Every ``fps``-th frame is detected, annotated and written to the output
    video; the other frames are skipped.

    Args:
        video_path: Path of the input video.
        output_path: Where to write the annotated video. Defaults to
            ``RESULT_DIR/face_<basename of video_path>``. This matches the
            ``result_file`` name that ``process_task`` publishes for videos.

    Returns:
        A list of ``{"frame": index, "detections": [...]}`` dicts, one per
        sampled frame. Returns ``None`` if the video or the output writer
        cannot be opened, or if processing raises.
    """
    if output_path is None:
        output_path = os.path.join(RESULT_DIR, f"face_{os.path.basename(video_path)}")

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"处理视频时出错: 无法打开视频 {video_path}")
            return None

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # Some containers report 0 FPS; sample every frame rather than
        # dividing by zero in the modulo below.
        if fps <= 0:
            fps = 1
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_shape = (height, width)

        # NOTE(review): only sampled frames are written, but the writer still
        # uses the source fps, so the output plays back ~fps times faster
        # than the source. Confirm this is intended.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                              fps, (width, height))
        if not out.isOpened():
            print(f"处理视频时出错: 无法创建输出视频 {output_path}")
            return None

        json_results = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % fps == 0:
                preprocessed_frame = preprocess_frame(frame)
                results = detector.detect(preprocessed_frame)
                frame_json_results = detector.format_results(results, original_shape)
                json_results.append({"frame": frame_count, "detections": frame_json_results})

                annotated_frame = detector.draw_results(frame, frame_json_results)
                out.write(annotated_frame)

            frame_count += 1

        return json_results
    except Exception as e:
        print(f"处理视频时出错: {str(e)}")
        return None
    finally:
        # Release native handles on every exit path, including exceptions.
        if out is not None:
            out.release()
        if cap is not None:
            cap.release()
|
|
|
|
def preprocess_frame(frame):
    """Turn a 3-channel BGR frame into a batched model input tensor.

    The frame goes through these steps:
    - converted BGR -> RGB;
    - resized to 640x640;
    - reordered from HWC to CHW;
    - cast to ``float32`` and divided by 255;
    - given a leading batch axis.

    The result has shape ``(1, 3, 640, 640)``. Presumably the input is the
    ``uint8`` array OpenCV returns, which makes the values fall in ``[0, 1]``.
    """
    rgb = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (640, 640))
    chw = np.ascontiguousarray(rgb.transpose((2, 0, 1)))
    scaled = torch.from_numpy(chw).float() / 255.0
    return scaled.unsqueeze(0)
|
|
|
|
def _parse_task(task):
    """Return ``(task_id, filename, file_type)`` from a task dict, or ``None``.

    Returns ``None`` when *task* is not a dict or lacks one of the keys. The
    caller can then skip a malformed message instead of crashing the loop.
    """
    if not isinstance(task, dict):
        return None
    try:
        return task['task_id'], task['filename'], task['file_type']
    except KeyError:
        return None


def _store_face_result(task_id, json_results, result_filename):
    """Save detections to ``face_result:<id>`` and mark ``task:<id>`` completed."""
    redis_client.hset(f"face_result:{task_id}", mapping={
        "result": json.dumps(json_results),
        "result_file": result_filename,
    })
    main_redis_client.hset(f"task:{task_id}", mapping={
        "status": "completed",
        "result_type": "face",
        "result_key": f"face_result:{task_id}",
    })


def _mark_task_failed(task_id, error=None):
    """Best-effort: set ``task:<id>`` status to ``failed`` (and ``error`` if given).

    Redis errors are logged rather than raised, so reporting a failure cannot
    itself stop the consumer loop.
    """
    mapping = {"status": "failed"}
    if error is not None:
        mapping["error"] = error
    try:
        main_redis_client.hset(f"task:{task_id}", mapping=mapping)
    except redis.exceptions.RedisError as e:
        print(f"更新任务 {task_id} 状态时出错: {str(e)}")


def process_task():
    """Consume tasks from Kafka forever and run face detection on each.

    For every message:
    1. Mark ``task:<id>`` in the main Redis DB as ``processing``.
    2. Run image or video detection on ``UPLOAD_DIR/<filename>``.
    3. On success, store the result in ``face_result:<id>`` (worker DB) and
       mark the task ``completed``. Otherwise mark it ``failed``.

    A malformed message, a Redis error or a failing task is logged and skipped,
    so one bad task no longer kills the consumer thread.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")

        parsed = _parse_task(message.value)
        if parsed is None:
            print(f"无效的任务消息,已跳过: offset={message.offset}")
            continue
        task_id, filename, file_type = parsed

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        # Check the key type and update the status. A stale key of another type
        # would make HSET fail with WRONGTYPE, so it is dropped first.
        task_key = f"task:{task_id}"
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.RedisError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and move on to the next message

        try:
            if file_type == "image":
                print(f"开始处理图像: {filename}")
                json_results, annotated_img = process_image(file_path)
                # An image with no faces yields [], which is still a success;
                # only None signals a processing failure.
                if json_results is not None and annotated_img is not None:
                    result_filename = f"face_{filename}"
                    result_path = os.path.join(RESULT_DIR, result_filename)
                    # cv2.imwrite reports failure via its return value.
                    if not cv2.imwrite(result_path, annotated_img):
                        raise OSError(f"无法写入结果文件: {result_path}")
                    _store_face_result(task_id, json_results, result_filename)
                    print(f"图像 {filename} 处理完成,结果已保存")
                else:
                    print(f"图像 {filename} 处理失败")
                    _mark_task_failed(task_id)
            else:  # any other file_type is treated as a video (unchanged behaviour)
                print(f"开始处理视频: {filename}")
                json_results = process_video(file_path)
                # An empty list means no frame could be decoded; treat it as a failure.
                if json_results:
                    _store_face_result(task_id, json_results, f"face_{filename}")
                    print(f"视频 {filename} 处理完成,结果已保存")
                else:
                    print(f"视频 {filename} 处理失败")
                    _mark_task_failed(task_id)
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            _mark_task_failed(task_id, str(e))

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
|
def listen_redis_changes():
|
|
pubsub = redis_client.pubsub()
|
|
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化
|
|
|
|
for message in pubsub.listen():
|
|
if message['type'] == 'pmessage':
|
|
key = message['channel'].decode('utf-8').split(':')[-1]
|
|
operation = message['data'].decode('utf-8')
|
|
|
|
if operation == 'hmset':
|
|
value = redis_client.hgetall(f"face_result:{key}")
|
|
if value:
|
|
result = {k.decode(): v.decode() for k, v in value.items()}
|
|
print(f"Result update for task {key}: {result}")
|
|
|
|
|
|
if __name__ == "__main__":
    print("face处理程序启动...")
    # Start the Kafka task-processing thread.
    task_thread = threading.Thread(target=process_task, daemon=True)
    task_thread.start()
    print("任务处理线程已启动")

    # Start the Redis keyspace listener thread (it only logs result updates).
    redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
    redis_thread.start()
    print("Redis监听线程已启动")

    print("主程序进入等待状态...")
    # Keep the main thread alive while tasks are consumed. process_task loops
    # forever, so it only returns if it crashed. In that case exit non-zero
    # instead of blocking on the logging-only listener thread; a supervisor
    # can then restart the worker rather than leave an idle process running.
    task_thread.join()
    print("任务处理线程已退出,程序终止")
    raise SystemExit(1)