# --- repository listing metadata (preserved from the original paste) ---
# File: api/api_history/function/pose.py
# Last commit: 2025-01-12 06:15:15 +00:00
# 184 lines, 6.3 KiB, Python
import json
import cv2
import numpy as np
from kafka import KafkaConsumer
from ultralytics import YOLO
import threading
import redis
import base64
import yaml
from PIL import Image
# 加载配置文件
with open('worker_sys/function/config.yaml', 'r') as file:
config = yaml.safe_load(file)
class YOLOv8nPoseProcessor:
    """Wraps a YOLOv8 pose model: runs inference and serializes detections to JSON."""

    def __init__(self, model_path):
        """Load the YOLO model from *model_path* (path to a weights file)."""
        self.model = YOLO(model_path)

    def process_image(self, img):
        """Run the model on *img* (an OpenCV BGR array) and return raw results."""
        results = self.model(img)
        return results

    def format_results(self, results):
        """Flatten YOLO results into a JSON string.

        Each detection becomes a dict with 'box' (xywh), 'keypoints' (xy pairs
        or None), integer 'class', 'class_name' and float 'confidence'. When
        nothing is detected, a single placeholder entry carrying a 'message'
        field is emitted so downstream consumers always get non-empty JSON.
        """
        result_data = []
        for result in results:
            boxes = result.boxes.xywh.tolist()
            # result.keypoints can be None (non-pose models) even when the
            # attribute exists; getattr also covers results lacking it entirely.
            kp = getattr(result, 'keypoints', None)
            keypoints = kp.xy.tolist() if kp is not None else None
            classes = result.boxes.cls.tolist()
            confs = result.boxes.conf.tolist()
            for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
                result_data.append({
                    'box': box,
                    # Guard the index: keypoints may be absent or shorter than boxes.
                    'keypoints': keypoints[i] if keypoints and i < len(keypoints) else None,
                    'class': int(cls),
                    'class_name': self.model.names[int(cls)],
                    'confidence': float(conf)
                })
        if not result_data:
            result_data.append({
                'box': None,
                'keypoints': None,
                'class': None,
                'class_name': None,
                'confidence': None,
                'message': 'No object detected'
            })
        return json.dumps(result_data)
class ImageProcessingNode:
    """One Kafka consumer that turns queued frame messages into pose results.

    Per message: fetch the base64-encoded image from Redis db0 by cache_key,
    decode it with OpenCV, run YOLOv8 pose inference, and write the JSON
    results (plus the message's object metadata) to Redis db1 under the same
    cache_key.
    """

    def __init__(self, bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
        """Create the Kafka consumer, the YOLO processor and two Redis clients.

        Args:
            bootstrap_servers: Kafka broker address(es).
            input_topic: Kafka topic carrying the frame messages.
            model_path: path to the YOLOv8 pose weights file.
            group_id: Kafka consumer-group id (also used as a log prefix).
            redis_host, redis_port, redis_password: Redis connection settings.
        """
        self.consumer = KafkaConsumer(
            input_topic,
            bootstrap_servers=bootstrap_servers,
            group_id=group_id,
            # Incoming payloads are UTF-8 JSON documents.
            value_deserializer=lambda x: json.loads(x.decode('utf-8')),
            auto_offset_reset='earliest'
        )
        self.input_topic = input_topic
        self.yolo_processor = YOLOv8nPoseProcessor(model_path)
        self.group_id = group_id
        # db0 holds source images (base64); db1 receives pose results (JSON).
        self.redis_client_db0 = redis.Redis(host=redis_host, port=redis_port, db=0, password=redis_password)
        self.redis_client_db1 = redis.Redis(host=redis_host, port=redis_port, db=1, password=redis_password)

    def process_and_produce(self, message):
        """Process one consumed message dict end-to-end.

        Expects keys 'cache_key', 'object', 'etag' and 'size' — presumably set
        by the upstream frame producer; verify against that producer. Any
        failure is printed and the frame is dropped (no retry, no re-queue).
        """
        cache_key = message.get('cache_key')
        object_key = message.get('object')
        etag = message.get('etag')
        size = message.get('size')
        print(f"Consumer {self.group_id} processing image with cache_key: {cache_key}")
        # Fetch the image data from Redis db0.
        img_str = self.redis_client_db0.get(cache_key)
        if img_str is None:
            print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
            return
        # Convert the base64-encoded image data into OpenCV (BGR ndarray) format.
        try:
            img_data = base64.b64decode(img_str)
            nparr = np.frombuffer(img_data, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                print(f"Error: Unable to decode image data for cache_key: {cache_key}")
                return
            results = self.yolo_processor.process_image(img)
            formatted_results = self.yolo_processor.format_results(results)
            # Build a dict carrying the pose results plus the original metadata.
            result_data = {
                "pose_results": json.loads(formatted_results),
                'cache_key': cache_key,
                "object": object_key,
                "etag": etag,
                "size": size
            }
            # Save the results to Redis db1 under the same cache_key.
            try:
                serialized_data = json.dumps(result_data)
                # NOTE(review): json.dumps never returns an empty string, so
                # this guard appears defensive/unreachable.
                if not serialized_data:
                    print(f"Error: Serialized data is empty for cache_key: {cache_key}")
                    return
                self.redis_client_db1.set(cache_key, serialized_data)
                print(f"Consumer {self.group_id} processed and saved results to Redis db1 with cache_key {cache_key}")
            except TypeError as e:
                # json.dumps raises TypeError for non-serializable values.
                print(f"Error serializing data: {e}")
                print(f"Problematic data: {result_data}")
            except Exception as e:
                print(f"Unexpected error when saving to Redis: {e}")
        except Exception as e:
            # Broad catch: any decode/inference failure is logged and the frame
            # is skipped instead of crashing the worker thread.
            print(f"Error processing image: {e}")

    def run(self):
        """Consume messages forever, spawning one worker thread per message.

        NOTE(review): thread creation is unbounded — a burst of messages
        spawns an equal number of threads; consider a pool if that matters.
        """
        print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
        for message in self.consumer:
            message_value = message.value
            cache_key = message_value.get('cache_key')
            if cache_key:
                threading.Thread(target=self.process_and_produce, args=(message_value,)).start()
            else:
                print(f"Consumer {self.group_id} error: Received message without cache_key")
def start_consumer(kafka_bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
    """Build an ImageProcessingNode for one consumer group and run it (blocks forever)."""
    processing_node = ImageProcessingNode(
        bootstrap_servers=kafka_bootstrap_servers,
        input_topic=input_topic,
        model_path=model_path,
        group_id=group_id,
        redis_host=redis_host,
        redis_port=redis_port,
        redis_password=redis_password,
    )
    print(f"Image Processing Node {group_id} initialized.")
    print(f"Listening on topic: {input_topic}")
    processing_node.run()
if __name__ == "__main__":
    # Pull the Kafka settings out of the loaded YAML config.
    kafka_cfg = config['kafka']
    frames_cfg = kafka_cfg['topics']['all_frames']
    kafka_bootstrap_servers = kafka_cfg['bootstrap_servers']
    input_topic = frames_cfg['name']
    num_consumers = frames_cfg['num_consumers']
    # Pose model weights path.
    model_path = config['model']['pose-path']
    # Redis connection settings.
    redis_cfg = config['redis']
    redis_host = redis_cfg['host']
    redis_port = redis_cfg['port']
    redis_password = redis_cfg['password']

    # Spawn one consumer thread per configured consumer group.
    consumer_threads = []
    for idx in range(num_consumers):
        worker = threading.Thread(
            target=start_consumer,
            args=(
                kafka_bootstrap_servers,
                input_topic,
                model_path,
                f'pose_group_{idx}',
                redis_host,
                redis_port,
                redis_password,
            ),
        )
        consumer_threads.append(worker)
        worker.start()

    # Block until every consumer thread exits (they normally run forever).
    for worker in consumer_threads:
        worker.join()