Files
2025-01-12 06:15:15 +00:00

299 lines
11 KiB
Python

import os
import cv2
import torch
import numpy as np
from redis import Redis
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from config import *
# Configuration.
# The names on the right-hand side are presumably provided by
# `from config import *`; the self-assignments (e.g. KAFKA_BROKER = KAFKA_BROKER)
# simply re-bind those names in this module.
MODEL_PATH = MEDIAPIPE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["mediapipe"]["kafka_topic"]
KAFKA_GROUP_ID = f"mediapipe_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["mediapipe"]["redis_db"] # Redis DB used by this worker for its results
MAIN_REDIS_DB = MAIN_REDIS_DB # Main Redis DB (holds the task:<id> status hashes)
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# Ensure the upload and result directories exist (no error if already present).
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# Initialize the Kafka consumer at import time.
# Message values are decoded from UTF-8 JSON into Python objects.
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# Initialize Redis.
# redis_client talks to the worker's own DB (REDIS_DB).
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
# main_redis_client talks to the main DB (MAIN_REDIS_DB) on the same server.
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class mediapipeEmbedder:
    """Extract face landmarks with MediaPipe and derive a statistical embedding.

    The detector is a MediaPipe ``FaceLandmarker`` configured for a single
    face. ``process_image`` is the main entry point: it decodes raw image
    bytes, runs detection, computes the embedding and draws the landmarks on
    the decoded image.
    """

    def __init__(self, model_path):
        """Create the FaceLandmarker detector from the model at ``model_path``."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1,
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Return the landmarks of the first detected face, or ``None``.

        Args:
            image: Pixel array handed to ``mp.Image`` as an SRGB image, so it
                must be in RGB channel order (not OpenCV's BGR).

        Returns:
            A NumPy array with one ``(x, y, z)`` row per landmark of the first
            detected face, or ``None`` when no face is found.
        """
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array(
                [(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]]
            )
        return None

    def process_image(self, image_data):
        """Decode ``image_data``, detect landmarks and annotate the image.

        Args:
            image_data: Encoded image bytes (anything ``cv2.imdecode`` accepts).

        Returns:
            ``(result, image)`` where ``image`` is the decoded BGR image. When a
            face is found, ``result`` is a dict with keys ``"embedding"`` and
            ``"landmarks"`` and the landmarks are drawn on ``image``; otherwise
            ``result`` is ``None`` and ``image`` is returned unmodified.

        Raises:
            ValueError: If ``image_data`` is empty or cannot be decoded.
        """
        if not image_data:
            raise ValueError("image data is empty")
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals failure by returning None instead of raising.
            raise ValueError("image data could not be decoded")

        # OpenCV decodes to BGR, but the detector input is declared as SRGB;
        # convert so the channels are not swapped for the model.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        landmarks = self.get_mediapipe_landmarks(rgb)
        if landmarks is None:
            return None, img

        embedding = self.calculate_detailed_embedding(landmarks)

        # Landmark x/y are scaled by the image width/height to get pixel
        # positions; drawing happens on the BGR image that is returned.
        height, width = img.shape[:2]
        for x, y, _z in landmarks:
            cv2.circle(img, (int(x * width), int(y * height)), 2, (0, 255, 0), -1)

        return {
            "embedding": embedding,
            "landmarks": landmarks.tolist(),
        }, img

    def calculate_detailed_embedding(self, landmarks):
        """Build a 21-value feature vector from an ``(N, 3)`` landmark array.

        Layout of the returned list:
            * per-axis mean, std, median, min, max (5 x 3 values)
            * eye distance, mouth width, nose-to-mouth-centre distance
            * per-axis extent (max - min) as face width, height, depth

        Args:
            landmarks: Array with one ``(x, y, z)`` row per landmark. It must
                contain at least 387 rows because fixed indices up to 386 are
                read.

        Returns:
            The features as a plain Python list of floats.
        """
        # Per-axis statistics over all landmarks.
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Distances between key facial landmarks (fixed mesh indices).
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]
        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Face width/height/depth are the per-axis extents; reuse the min/max
        # already computed instead of scanning every column again.
        face_extent = max_vals - min_vals

        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth],
            face_extent,
        ])
        return embedding.tolist()
# Module-level embedder instance used by process_image() and process_video();
# constructing it builds the FaceLandmarker detector from MODEL_PATH at import time.
embedder = mediapipeEmbedder(MODEL_PATH)
def process_image(image_data, filename):
    """Run landmark detection on one image and save the annotated copy.

    Args:
        image_data: Encoded image bytes.
        filename: Original file name; the annotated image is written to
            ``RESULT_DIR`` as ``mediapipe_<filename>``.

    Returns:
        ``(results, annotated_filename)`` on success, or ``(None, None)`` when
        no face is detected, the annotated image cannot be written, or any
        error occurs (errors are printed, not raised).
    """
    try:
        results, annotated_img = embedder.process_image(image_data)
        if not results:
            print(f"No face landmarks detected in image: {filename}")
            return None, None

        # Save annotated image
        annotated_filename = f"mediapipe_{filename}"
        annotated_path = os.path.join(RESULT_DIR, annotated_filename)
        # cv2.imwrite reports most failures by returning False rather than
        # raising; without this check a missing result file would still be
        # reported to the caller as a success.
        if not cv2.imwrite(annotated_path, annotated_img):
            print(f"Failed to write annotated image: {annotated_path}")
            return None, None
        return results, annotated_filename
    except Exception as e:
        print(f"Error processing image {filename}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
def process_video(video_data, filename):
    """Run landmark detection on roughly one frame per second of a video.

    The video bytes are written to a temporary file in ``UPLOAD_DIR``, read
    back frame by frame, and an annotated copy is written to ``RESULT_DIR`` as
    ``mediapipe_<filename>``. Sampled frames are passed through the embedder;
    all other frames are copied unchanged.

    Args:
        video_data: Raw bytes of the video file.
        filename: Original file name, used to derive temp and output names.

    Returns:
        ``(results, annotated_filename)`` where ``results`` is a list of
        ``{"frame": index, "results": ...}`` dicts for sampled frames in which
        a face was found, or ``(None, None)`` if any error occurs (the error
        is printed, not raised).
    """
    # Used only when the source does not report a usable frame rate.
    fallback_fps = 25

    temp_video_path = None
    cap = None
    out = None
    try:
        temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # The reported rate can be 0 (unknown, or below 1 fps after int());
        # using it directly as a modulus would raise ZeroDivisionError.
        sample_interval = max(1, fps)
        writer_fps = fps if fps > 0 else fallback_fps

        annotated_filename = f"mediapipe_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(
            output_path, cv2.VideoWriter_fourcc(*'mp4v'), writer_fps, (width, height)
        )

        frame_count = 0
        results = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % sample_interval == 0:
                # The embedder takes encoded bytes, so re-encode the frame.
                encoded = cv2.imencode('.jpg', frame)[1].tobytes()
                frame_results, annotated_frame = embedder.process_image(encoded)
                if frame_results:
                    results.append({"frame": frame_count, "results": frame_results})
            else:
                annotated_frame = frame
            out.write(annotated_frame)
            frame_count += 1
        return results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Always release the capture/writer and drop the temp copy, including
        # when a frame fails mid-way; the writer must be released for the
        # output file to be finalised.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if temp_video_path is not None:
            try:
                os.remove(temp_video_path)
            except FileNotFoundError:
                # The temp file was never created (e.g. open() failed).
                pass
            except OSError as e:
                print(f"Could not remove temporary video {temp_video_path}: {str(e)}")
def _process_file(file_type, file_path, filename):
    """Read the uploaded file and run the image or video pipeline on it.

    Any ``file_type`` other than ``"image"`` is treated as a video.
    Returns ``(label, results, annotated_filename)`` where ``label`` is the
    word used in log messages for this kind of file.
    """
    is_image = file_type == "image"
    label = "图像" if is_image else "视频"
    print(f"开始处理{label}: {filename}")
    with open(file_path, 'rb') as f:
        data = f.read()
    handler = process_image if is_image else process_video
    results, annotated_filename = handler(data, filename)
    return label, results, annotated_filename


def _record_success(task_id, results, annotated_filename):
    """Store the result in the worker DB and mark the task completed in the main DB."""
    result_key = f"mediapipe_result:{task_id}"
    redis_client.hset(result_key, mapping={
        "result": json.dumps(results),
        "result_file": annotated_filename,
    })
    main_redis_client.hset(f"task:{task_id}", mapping={
        "status": "completed",
        "result_type": "mediapipe",
        "result_key": result_key,
    })


def _mark_failed(task_id, error=None):
    """Mark the task failed in the main DB without letting a Redis error escape."""
    fields = {"status": "failed"}
    if error is not None:
        fields["error"] = error
    try:
        main_redis_client.hset(f"task:{task_id}", mapping=fields)
    except redis.exceptions.RedisError as e:
        # A Redis failure here must not terminate the consumer loop.
        print(f"Could not mark task {task_id} as failed: {str(e)}")


def process_task():
    """Consume tasks from Kafka forever and process each image or video.

    Each message value is expected to be a dict with ``task_id``,
    ``filename`` and ``file_type``. For every task the status hash
    ``task:<task_id>`` in the main Redis DB is set to ``processing`` and then
    to ``completed`` or ``failed``; on success the result is stored in the
    worker DB under ``mediapipe_result:<task_id>``. Errors in a single task
    are logged and do not stop the loop.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        try:
            task_id = task['task_id']
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # A malformed message would otherwise kill the consumer thread.
            print(f"Skipping malformed task message at offset {message.offset}: {e!r}")
            continue
        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
        file_path = os.path.join(UPLOAD_DIR, filename)

        # Check the key type and update the status: if the key exists with a
        # non-hash type, drop it so the HSET below cannot fail on type mismatch.
        task_key = f"task:{task_id}"
        try:
            key_type = main_redis_client.type(task_key)
            if key_type != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # Skip this task and move on to the next one.

        try:
            label, json_results, annotated_filename = _process_file(
                file_type, file_path, filename
            )
            if json_results and annotated_filename:
                _record_success(task_id, json_results, annotated_filename)
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                _mark_failed(task_id)
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            # Several fields must be passed via mapping=; a dict passed
            # positionally is taken as the field name and makes hset raise.
            _mark_failed(task_id, str(e))
        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
    """Print every update made to a ``mediapipe_result:*`` hash.

    Subscribes to keyspace notifications for the worker DB and, on each
    ``hset`` event, reads the hash back and prints its decoded contents.
    Blocks forever.

    NOTE(review): keyspace notifications are only delivered if the Redis
    server has ``notify-keyspace-events`` enabled — confirm the deployment
    configuration.
    """
    result_prefix = "mediapipe_result:"
    pubsub = redis_client.pubsub()
    # Keyspace channels embed the DB index, so it must be the DB that
    # redis_client writes the result hashes to rather than a fixed number.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:{result_prefix}*')
    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue
        if message['data'].decode('utf-8') != 'hset':
            continue
        # Channel looks like "__keyspace@<db>__:<key>"; split only on the
        # first ':' so a task id that itself contains ':' is kept intact.
        channel = message['channel'].decode('utf-8')
        result_key = channel.split(':', 1)[1]
        task_id = result_key[len(result_prefix):]
        value = redis_client.hgetall(result_key)
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {task_id}: {result}")
if __name__ == "__main__":
    print("mediapipe处理程序启动...")
    # Background workers in start order: the Kafka task processor first, then
    # the Redis change listener. Each entry pairs the thread target with the
    # message printed once that thread has been started.
    workers = (
        (process_task, "任务处理线程已启动"),
        (listen_redis_changes, "Redis监听线程已启动"),
    )
    running = []
    for target, started_message in workers:
        worker = threading.Thread(target=target, daemon=True)
        worker.start()
        running.append(worker)
        print(started_message)
    print("主程序进入等待状态...")
    # Keep the main thread alive while the daemon threads run.
    for worker in running:
        worker.join()