299 lines
11 KiB
Python
299 lines
11 KiB
Python
import os
|
|
import cv2
|
|
import torch
|
|
import numpy as np
|
|
from redis import Redis
|
|
import json
|
|
from kafka import KafkaConsumer
|
|
import threading
|
|
import redis
|
|
import torch
|
|
import mediapipe as mp
|
|
from mediapipe.tasks import python
|
|
from mediapipe.tasks.python import vision
|
|
from config import *
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Every right-hand-side name below comes from ``from config import *``;
# re-binding them here documents exactly which settings this worker reads.
MODEL_PATH = MEDIAPIPE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["mediapipe"]["kafka_topic"]
KAFKA_GROUP_ID = f"mediapipe_{KAFKA_GROUP_ID_PREFIX}"

REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["mediapipe"]["redis_db"]  # DB this worker writes its mediapipe_result:* hashes to
MAIN_REDIS_DB = MAIN_REDIS_DB  # DB holding the shared task:<id> status hashes

UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR

# Make sure the input and output directories exist before any task arrives.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Kafka consumer for incoming task messages; each message value is decoded
# from UTF-8 JSON.
# NOTE(review): with enable_auto_commit=True the offset may be committed before
# a long task finishes, so a crash mid-task could drop that task — confirm intended.
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=lambda raw_bytes: json.loads(raw_bytes.decode('utf-8')),
)

# Two clients on the same Redis server: one for this worker's result DB and
# one for the main DB where task status lives.
redis_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DB)
main_redis_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=MAIN_REDIS_DB)
|
|
|
|
|
|
class mediapipeEmbedder:
    """Wrap a MediaPipe FaceLandmarker and derive a hand-crafted face embedding.

    The detector is configured for a single face with blendshapes and
    transformation matrices enabled; only the landmark list is used here.
    """

    def __init__(self, model_path):
        """Create the FaceLandmarker from the model asset at *model_path*."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1,
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Detect face landmarks in an image given in RGB channel order.

        Args:
            image: H x W x 3 array in RGB order (it is wrapped as
                ``mp.ImageFormat.SRGB``); presumably uint8 as produced by
                ``cv2.imdecode`` + ``cv2.cvtColor``.

        Returns:
            An (N, 3) array of ``(x, y, z)`` landmarks for the first detected
            face, or ``None`` when no face is found.
        """
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
        return None

    def process_image(self, image_data):
        """Decode encoded image bytes, detect landmarks and annotate the image.

        Args:
            image_data: Encoded image bytes (any format ``cv2.imdecode`` accepts).

        Returns:
            ``(result, annotated_img)``.  *result* is
            ``{"embedding": [...], "landmarks": [[x, y, z], ...]}`` or ``None``
            when no face is found; *annotated_img* is the decoded BGR image with
            landmarks drawn as green dots.  Returns ``(None, None)`` when the
            bytes cannot be decoded.
        """
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # Corrupt/undecodable input: nothing to detect or annotate.
            return None, None

        # OpenCV decodes to BGR but the SRGB MediaPipe image expects RGB;
        # without this conversion the detector sees red and blue swapped.
        rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        landmarks = self.get_mediapipe_landmarks(rgb_img)
        if landmarks is None:
            return None, img

        embedding = self.calculate_detailed_embedding(landmarks)

        # Landmark x/y are treated as normalized coordinates: scale by image size.
        height, width = img.shape[:2]
        for x, y, _ in landmarks:
            cv2.circle(img, (int(x * width), int(y * height)), 2, (0, 255, 0), -1)

        return {
            "embedding": embedding,
            "landmarks": landmarks.tolist(),
        }, img

    def calculate_detailed_embedding(self, landmarks):
        """Build a 21-value feature vector from a landmark array.

        Args:
            landmarks: (N, 3) array-like of ``(x, y, z)`` landmarks.  Indices
                4, 61, 159, 291 and 386 are read (presumably MediaPipe
                face-mesh indices), so N must be at least 387.

        Returns:
            A list of 21 floats: per-axis mean, std, median, min and max
            (5 x 3 values), followed by eye distance, mouth width,
            nose-to-mouth-centre distance and the x / y / z extents.
        """
        landmarks = np.asarray(landmarks, dtype=float)

        # Per-axis summary statistics.
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Distances between selected facial landmarks.
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]

        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Extents along x, y, z (per-axis max - min).
        face_width, face_height, face_depth = max_vals - min_vals

        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth],
        ])
        return embedding.tolist()
|
|
|
|
# Single embedder instance used by process_image and process_video; the
# MediaPipe model is therefore loaded once, at import time.
embedder = mediapipeEmbedder(model_path=MODEL_PATH)
|
|
|
|
def process_image(image_data, filename):
    """Run the embedder on one image and save the annotated copy.

    Args:
        image_data: Encoded image bytes.
        filename: Original upload name; the annotated image is written to
            ``RESULT_DIR/mediapipe_<filename>``.

    Returns:
        ``(results, annotated_filename)`` on success, otherwise ``(None, None)``
        (no face detected, the annotated image could not be written, or an
        exception occurred; the traceback is printed in that last case).
    """
    try:
        results, annotated_img = embedder.process_image(image_data)
        if not results:
            print(f"No face landmarks detected in image: {filename}")
            return None, None

        annotated_filename = f"mediapipe_{filename}"
        annotated_path = os.path.join(RESULT_DIR, annotated_filename)
        # cv2.imwrite may signal failure by returning False instead of raising;
        # without this check the task would be marked completed with a
        # result file that does not exist.
        if not cv2.imwrite(annotated_path, annotated_img):
            print(f"Failed to write annotated image: {annotated_path}")
            return None, None

        return results, annotated_filename
    except Exception as e:
        print(f"Error processing image {filename}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
|
|
|
|
def process_video(video_data, filename):
    """Run the embedder on sampled video frames and save an annotated copy.

    The bytes are spilled to ``UPLOAD_DIR/mediapipe_<filename>`` so OpenCV can
    open them; that temp file is always removed afterwards.  Every
    ``int(fps)``-th frame (roughly one per second) is run through the
    embedder.  Frames where a face was found are written annotated; all
    other frames are copied unchanged into ``RESULT_DIR/mediapipe_<filename>``.

    Returns:
        ``(results, annotated_filename)`` where *results* is a list of
        ``{"frame": index, "results": {...}}`` dicts, or ``(None, None)`` if
        the video cannot be opened or an error occurs.
    """
    # Frame rate assumed when OpenCV cannot read one from the container
    # (CAP_PROP_FPS == 0); avoids a ZeroDivisionError in the sampling modulo
    # and an invalid VideoWriter.
    fallback_fps = 30.0

    temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
    annotated_filename = f"mediapipe_{filename}"
    output_path = os.path.join(RESULT_DIR, annotated_filename)
    cap = None
    out = None
    try:
        # OpenCV needs a file path, so write the bytes to a temp file first.
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            print(f"Error processing video: cannot open {filename}")
            return None, None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # also catches NaN
            fps = fallback_fps
        sample_every = max(1, int(fps))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the exact (possibly fractional) fps so output timing matches the source.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        results = []
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            annotated_frame = frame
            if frame_count % sample_every == 0:
                # embedder.process_image expects encoded bytes, hence the JPEG round-trip.
                encoded_ok, encoded = cv2.imencode('.jpg', frame)
                if encoded_ok:
                    frame_results, processed_frame = embedder.process_image(encoded.tobytes())
                    if frame_results:
                        results.append({"frame": frame_count, "results": frame_results})
                        annotated_frame = processed_frame
            out.write(annotated_frame)
            frame_count += 1

        return results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Release OpenCV handles and delete the temp copy even when processing fails.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        try:
            os.remove(temp_video_path)
        except FileNotFoundError:
            pass  # temp file was never created
        except OSError as e:
            print(f"Failed to remove temp video {temp_video_path}: {e}")
|
|
|
|
def _mark_task_completed(task_id, json_results, annotated_filename):
    """Store results in the worker DB, then flag the main task as completed."""
    result_key = f"mediapipe_result:{task_id}"
    redis_client.hset(result_key, mapping={
        "result": json.dumps(json_results),
        "result_file": annotated_filename,
    })
    # A single HSET so readers never see status=completed without result_key.
    main_redis_client.hset(f"task:{task_id}", mapping={
        "status": "completed",
        "result_type": "mediapipe",
        "result_key": result_key,
    })


def _mark_task_failed(task_id, error=None):
    """Set ``status=failed`` (and ``error`` when given); never raises on Redis errors."""
    fields = {"status": "failed"}
    if error is not None:
        fields["error"] = error
    try:
        # mapping= is required: passing a dict positionally raises DataError.
        main_redis_client.hset(f"task:{task_id}", mapping=fields)
    except redis.exceptions.RedisError as redis_err:
        print(f"Failed to mark task {task_id} as failed: {redis_err}")


def process_task():
    """Consume task messages from Kafka forever and process each one.

    Each message value is expected to be a dict with ``task_id``,
    ``filename`` and ``file_type`` (``"image"``; any other value is handled
    as a video).  The file is read from ``UPLOAD_DIR/<filename>``.  The hash
    ``task:<task_id>`` in the main Redis DB is set to ``processing``, then to
    ``completed`` (with ``result_type``/``result_key``) or ``failed`` (with
    ``error`` when an exception occurred).  Malformed messages and per-task
    failures are logged and skipped so one bad task cannot stop the loop.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        try:
            task_id = task['task_id']
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # Without this guard a malformed message would end the consumer thread.
            print(f"Skipping malformed task message {task!r}: {e!r}")
            continue

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
        file_path = os.path.join(UPLOAD_DIR, filename)

        # Ensure task:<id> is a hash (drop a stale key of another type), then mark processing.
        task_key = f"task:{task_id}"
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and move on to the next one

        is_image = file_type == "image"
        kind_label = "图像" if is_image else "视频"
        try:
            print(f"开始处理{kind_label}: {filename}")
            with open(file_path, 'rb') as f:
                file_data = f.read()
            if is_image:
                json_results, annotated_filename = process_image(file_data, filename)
            else:  # anything that is not "image" is treated as a video
                json_results, annotated_filename = process_video(file_data, filename)

            if json_results and annotated_filename:
                _mark_task_completed(task_id, json_results, annotated_filename)
                print(f"{kind_label} {filename} 处理完成,结果已保存")
            else:
                print(f"{kind_label} {filename} 处理失败")
                _mark_task_failed(task_id)
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            _mark_task_failed(task_id, str(e))

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
|
|
|
def listen_redis_changes():
    """Log every HSET on ``mediapipe_result:*`` keys in the worker Redis DB.

    Relies on Redis keyspace notifications.  The server must have
    ``notify-keyspace-events`` enabled for keyspace hash events (e.g. ``Kh``);
    this function does not configure that.  Loops over pub/sub messages
    indefinitely.
    """
    pubsub = redis_client.pubsub()
    # Subscribe on the DB redis_client actually uses (REDIS_DB) rather than a
    # hard-coded index, so notifications arrive whatever the config says.
    channel_prefix = f"__keyspace@{REDIS_DB}__:mediapipe_result:"
    pubsub.psubscribe(channel_prefix + "*")

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue
        channel = message['channel'].decode('utf-8')
        # Strip the known prefix instead of splitting on ':' so task IDs that
        # themselves contain ':' stay intact.
        key = channel[len(channel_prefix):]
        operation = message['data'].decode('utf-8')
        if operation != 'hset':
            continue
        value = redis_client.hgetall(f"mediapipe_result:{key}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {key}: {result}")
|
|
|
|
|
|
if __name__ == "__main__":
    print("mediapipe处理程序启动...")

    # Start the worker threads in order: the Kafka task consumer first, then
    # the Redis keyspace listener.  Both are daemon threads.
    workers = []
    for target, started_msg in (
        (process_task, "任务处理线程已启动"),
        (listen_redis_changes, "Redis监听线程已启动"),
    ):
        worker = threading.Thread(target=target, daemon=True)
        worker.start()
        print(started_msg)
        workers.append(worker)

    print("主程序进入等待状态...")
    # Keep the main thread alive by waiting on each worker in start order.
    for worker in workers:
        worker.join()