Files
2025-01-12 06:15:15 +00:00

292 lines
10 KiB
Python

import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
from config import *
# Pin this process to the GPU configured for the pose worker before any model is loaded.
torch.cuda.set_device(CUDA_DEVICE_1)

# Configuration. Every upper-case name read here comes from the star-imported
# `config` module; the self-assignments only make those dependencies explicit.
_POSE_WORKER = WORKER_CONFIGS["pose"]
MODEL_PATH = POSE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = _POSE_WORKER["kafka_topic"]
KAFKA_GROUP_ID = f"pose_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = _POSE_WORKER["redis_db"]  # Redis DB this worker writes its results to
MAIN_REDIS_DB = MAIN_REDIS_DB  # Redis DB holding the shared task records
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR

# Make sure the input and output directories exist.
for _directory in (UPLOAD_DIR, RESULT_DIR):
    os.makedirs(_directory, exist_ok=True)


def _deserialize_task(raw_bytes):
    """Decode a Kafka message payload (UTF-8 encoded JSON) into a Python object."""
    return json.loads(raw_bytes.decode('utf-8'))


# Kafka consumer for pose tasks.
# NOTE(review): with enable_auto_commit=True offsets are committed independently of
# whether a task finished successfully, so a crash mid-task may lose it — confirm.
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=_deserialize_task,
)

# Two connections to the same Redis server: one for this worker's result DB,
# one for the main task-status DB.
_redis_connection = dict(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
redis_client = Redis(db=REDIS_DB, **_redis_connection)
main_redis_client = Redis(db=MAIN_REDIS_DB, **_redis_connection)
class PoseDetector:
    """Wrapper around an Ultralytics YOLO pose model pinned to CUDA_DEVICE_1.

    Provides three operations:
    - inference;
    - conversion of raw results into JSON-serialisable dicts in the coordinate
      space of the original image;
    - drawing of those dicts onto a frame.
    """

    def __init__(self, model_path):
        """Load the YOLO weights at ``model_path`` and move the model to CUDA_DEVICE_1."""
        self.model = YOLO(model_path).to(CUDA_DEVICE_1)

    def detect(self, frame):
        """Run the model on ``frame`` and return the raw Ultralytics results list.

        Callers in this file pass a 1x3x640x640 float tensor (RGB, scaled by 1/255).
        """
        return self.model(frame, device=CUDA_DEVICE_1)

    def format_results(self, results, original_shape):
        """Convert raw results into JSON-serialisable detection dicts.

        Coordinates are rescaled from the model input size (``r.orig_shape``) to
        ``original_shape`` (``(height, width[, channels])``). The rescaling is a
        plain per-axis scale, which matches inputs that were resized without
        letterboxing.

        Returns a list with one dict per detected box:
            ``"bbox"``: ``[x1, y1, x2, y2]`` as Python floats,
            ``"confidence"``: float,
            ``"keypoints"``: list of ``[x, y]`` float pairs.
        """
        orig_h, orig_w = original_shape[:2]
        formatted_results = []
        for r in results:
            boxes = r.boxes
            keypoints = r.keypoints
            # r.orig_shape is the size of what was fed to the model, so these
            # factors map model-space coordinates back onto the source image.
            model_h, model_w = r.orig_shape
            scale_x, scale_y = orig_w / model_w, orig_h / model_h
            for i in range(len(boxes)):
                box = boxes[i]
                # .tolist() yields plain Python floats; numpy scalars such as
                # np.float32 are rejected by json.dumps.
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().tolist()
                bbox_scaled = [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y]
                kpts_scaled = keypoints[i].xy[0].cpu().numpy() * np.array([scale_x, scale_y])
                formatted_results.append({
                    "bbox": bbox_scaled,
                    "confidence": box.conf.item(),
                    "keypoints": kpts_scaled.tolist()
                })
        return formatted_results

    def draw_results(self, frame, results):
        """Return a copy of ``frame`` with the formatted detections drawn on it.

        Each bbox is drawn as a green rectangle and each keypoint as a filled
        blue circle (colours given in OpenCV's BGR order). ``frame`` itself is
        not modified.
        """
        # NOTE(review): keypoints the model could not locate may arrive as (0, 0)
        # and would then be drawn in the top-left corner — confirm and filter if so.
        annotated_frame = frame.copy()
        for r in results:
            bbox = r["bbox"]
            cv2.rectangle(annotated_frame,
                          (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])),
                          (0, 255, 0), 2)
            for kp in r["keypoints"]:
                cv2.circle(annotated_frame,
                           (int(kp[0]), int(kp[1])),
                           5, (255, 0, 0), -1)
        return annotated_frame
# Single module-level model instance shared by the image and video pipelines.
detector = PoseDetector(model_path=MODEL_PATH)
def process_image(image_path):
    """Run pose detection on a single image file.

    Returns ``(detections, annotated_image)`` where:
    - ``detections`` is the list built by ``detector.format_results``, which
      may be empty when nobody is detected;
    - ``annotated_image`` is a copy of the input with boxes and keypoints drawn.

    On any error the message is printed and ``(None, None)`` is returned
    instead of raising.
    """
    try:
        original_img = cv2.imread(image_path)
        if original_img is None:
            # cv2.imread signals a missing or unreadable file by returning None.
            print(f"处理图像时出错: 无法读取图像 {image_path}")
            return None, None
        # Same preprocessing as the video path (RGB, 640x640, CHW, /255, batch of 1).
        results = detector.detect(preprocess_frame(original_img))
        json_results = detector.format_results(results, original_img.shape)
        annotated_img = detector.draw_results(original_img, json_results)
        return json_results, annotated_img
    except Exception as e:
        print(f"处理图像时出错: {str(e)}")
        return None, None
def process_video(video_path):
    """Run pose detection on a video file and write an annotated copy.

    Detection runs on every ``int(fps)``-th frame (frame 0 included), i.e. about
    once per second. Every source frame is written to
    ``RESULT_DIR/pose_<basename>`` with the most recent detections drawn on it,
    so the output keeps the source length and frame rate.

    Returns a list of ``{"frame": index, "detections": [...]}`` dicts, one per
    sampled frame. The list is empty if no frame could be read. Returns None
    when the video cannot be opened or an error occurs; the error is printed,
    not raised.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"处理视频时出错: 无法打开视频 {video_path}")
            return None
        src_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some sources report 0 FPS; fall back to sampling every frame rather
        # than dividing by zero in the modulo below.
        sample_every = max(int(src_fps), 1)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_shape = (height, width)
        # Matches the "pose_<filename>" name that process_task records as
        # result_file (for plain file names), and always lands in RESULT_DIR.
        output_path = os.path.join(RESULT_DIR, f"pose_{os.path.basename(video_path)}")
        # Use the exact (possibly fractional) source rate so playback speed is preserved.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                              src_fps if src_fps > 0 else 1.0, (width, height))
        json_results = []
        latest_detections = []
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % sample_every == 0:
                results = detector.detect(preprocess_frame(frame))
                latest_detections = detector.format_results(results, original_shape)
                json_results.append({"frame": frame_count, "detections": latest_detections})
            # Non-sampled frames reuse the latest detections so the overlay stays
            # visible and no frame is dropped or duplicated in the output.
            out.write(detector.draw_results(frame, latest_detections))
            frame_count += 1
        return json_results
    except Exception as e:
        print(f"处理视频时出错: {str(e)}")
        return None
    finally:
        # Release native handles even when processing fails part-way.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def preprocess_frame(frame):
    """Turn an OpenCV BGR image into a model-ready batch tensor.

    The processing steps are:
    - convert to RGB;
    - resize to 640x640 (plain resize, no letterboxing);
    - reorder from HWC to CHW;
    - cast to float32 and divide by 255 (giving [0, 1] for 8-bit input);
    - add a leading batch axis, giving shape (1, 3, 640, 640).
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    chw = np.ascontiguousarray(cv2.resize(rgb, (640, 640)).transpose(2, 0, 1))
    return (torch.from_numpy(chw).float() / 255.0).unsqueeze(0)
def _save_pose_result(task_id, json_results, result_filename):
    """Store pose output and mark the task completed.

    Writes ``pose_result:<task_id>`` (JSON detections plus result file name)
    in the worker Redis DB. It then sets ``task:<task_id>`` in the main DB to
    ``completed``, with a pointer to that key.
    """
    result_key = f"pose_result:{task_id}"
    redis_client.hmset(result_key, {
        "result": json.dumps(json_results),
        "result_file": result_filename
    })
    main_redis_client.hmset(f"task:{task_id}", {
        "status": "completed",
        "result_type": "pose",
        "result_key": result_key
    })


def process_task():
    """Consume pose tasks from Kafka forever, one message at a time.

    Each message value is expected to be a dict with ``task_id``, ``filename``
    and ``file_type`` keys; the file is read from UPLOAD_DIR. A ``file_type``
    of ``"image"`` goes through process_image; any other value goes through
    process_video.

    The task hash in the main Redis DB moves from ``processing`` to either
    ``completed`` or ``failed``. Malformed messages, Redis errors and
    per-task failures are printed and skipped so the consumer loop keeps
    running.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        try:
            task_id = task['task_id']
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # A message without the expected fields must not kill this thread.
            print(f"任务消息格式无效, 已跳过: {e!r}")
            continue
        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"

        # Mark the task as processing. A leftover key of a non-hash type would make
        # HSET fail with WRONGTYPE, so it is deleted first.
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.RedisError as e:
            # RedisError also covers connection failures, not just ResponseError.
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and move on to the next message

        result_filename = f"pose_{filename}"
        try:
            if file_type == "image":
                print(f"开始处理图像: {filename}")
                json_results, annotated_img = process_image(file_path)
                # [] (nobody detected) is a valid result; only None signals failure.
                if json_results is not None and annotated_img is not None:
                    result_path = os.path.join(RESULT_DIR, result_filename)
                    if not cv2.imwrite(result_path, annotated_img):
                        raise OSError(f"无法写入结果图像: {result_path}")
                    _save_pose_result(task_id, json_results, result_filename)
                    print(f"图像 {filename} 处理完成,结果已保存")
                else:
                    print(f"图像 {filename} 处理失败")
                    main_redis_client.hset(task_key, "status", "failed")
            else:  # any other file_type is handled as a video
                print(f"开始处理视频: {filename}")
                json_results = process_video(file_path)
                # Frame 0 is always sampled, so an empty list means no frame was read.
                if json_results:
                    _save_pose_result(task_id, json_results, result_filename)
                    print(f"视频 {filename} 处理完成,结果已保存")
                else:
                    print(f"视频 {filename} 处理失败")
                    main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            try:
                main_redis_client.hmset(task_key, {
                    "status": "failed",
                    "error": str(e)
                })
            except redis.exceptions.RedisError as redis_error:
                # Do not let a Redis outage while reporting the failure stop the consumer.
                print(f"更新任务 {task_id} 状态时出错: {str(redis_error)}")
        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
    """Print the contents of every ``pose_result:*`` hash written in the worker DB.

    Uses Redis keyspace notifications, which are disabled by default. This
    relies on the server's ``notify-keyspace-events`` setting, which is not
    configured here; if it is off, this loop simply never receives anything.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels embed the DB index; derive it from the DB this worker
    # writes to instead of hard-coding it.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:pose_result:*')
    for message in pubsub.listen():
        # Skip subscription confirmations and any other non-pattern messages.
        if message['type'] != 'pmessage':
            continue
        # Channel is "__keyspace@<db>__:<key>"; split once so ids containing ':' survive.
        result_key = message['channel'].decode('utf-8').split(':', 1)[1]
        operation = message['data'].decode('utf-8')
        # HSET and HMSET both publish the "hset" keyspace event; "hmset" is kept
        # as well so the original intent still matches.
        if operation not in ('hset', 'hmset'):
            continue
        value = redis_client.hgetall(result_key)
        if value:
            task_id = result_key.split(':', 1)[1]
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {task_id}: {result}")
if __name__ == "__main__":
    print("pose处理程序启动...")
    # Start the Kafka task loop first, then the Redis listener. They are daemon
    # threads, so they never keep the interpreter alive on their own.
    workers = []
    for worker_target, started_message in (
        (process_task, "任务处理线程已启动"),
        (listen_redis_changes, "Redis监听线程已启动"),
    ):
        worker = threading.Thread(target=worker_target, daemon=True)
        worker.start()
        print(started_message)
        workers.append(worker)
    print("主程序进入等待状态...")
    # Block the main thread on the workers, in start order.
    for worker in workers:
        worker.join()