184 lines
6.3 KiB
Python
184 lines
6.3 KiB
Python
|
|
import json
|
|
import cv2
|
|
import numpy as np
|
|
from kafka import KafkaConsumer
|
|
from ultralytics import YOLO
|
|
import threading
|
|
import redis
|
|
import base64
|
|
import yaml
|
|
from PIL import Image
|
|
|
|
# Load the YAML configuration once at import time; the script entry point at
# the bottom of this file reads its Kafka, model and Redis settings from it.
# The encoding is passed explicitly so parsing does not depend on the
# platform's locale default.
with open('worker_sys/function/config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
|
|
|
|
class YOLOv8nPoseProcessor:
    """Run a YOLO model on images and serialise its detections to JSON."""

    def __init__(self, model_path):
        """Load the YOLO model stored at ``model_path``."""
        self.model = YOLO(model_path)

    def process_image(self, img):
        """Run the model on ``img`` and return whatever the model call yields.

        ``img`` is handed to the model unchanged; the caller in this file
        passes an image decoded by ``cv2.imdecode``.
        """
        return self.model(img)

    def format_results(self, results):
        """Flatten model results into a JSON string.

        Args:
            results: Iterable of result objects exposing ``boxes`` (with
                ``xywh``, ``cls`` and ``conf``) and, optionally, ``keypoints``
                (with ``xy``).

        Returns:
            A JSON array (as ``str``) with one object per detected box holding
            ``box``, ``keypoints``, ``class``, ``class_name`` and
            ``confidence``. When nothing was detected the array contains a
            single placeholder object whose fields are ``null`` plus a
            ``message`` entry.
        """
        result_data = []
        for result in results:
            boxes_obj = getattr(result, 'boxes', None)
            if boxes_obj is None:
                # Nothing to report for this result; avoid dereferencing None.
                continue

            # The attribute can exist and still be None (e.g. weights without
            # a pose head), so test the value rather than using hasattr().
            keypoints_obj = getattr(result, 'keypoints', None)
            keypoints = keypoints_obj.xy.tolist() if keypoints_obj is not None else None

            boxes = boxes_obj.xywh.tolist()
            classes = boxes_obj.cls.tolist()
            confs = boxes_obj.conf.tolist()

            for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
                class_id = int(cls)
                has_keypoints = bool(keypoints) and i < len(keypoints)
                result_data.append({
                    'box': box,
                    'keypoints': keypoints[i] if has_keypoints else None,
                    'class': class_id,
                    'class_name': self.model.names[class_id],
                    'confidence': float(conf),
                })

        if not result_data:
            # Keep the payload shape stable for consumers when nothing is found.
            result_data.append({
                'box': None,
                'keypoints': None,
                'class': None,
                'class_name': None,
                'confidence': None,
                'message': 'No object detected',
            })

        return json.dumps(result_data)
|
|
|
|
class ImageProcessingNode:
    """Kafka consumer that runs the YOLO processor on images cached in Redis.

    Each consumed message carries a ``cache_key``. The base64 image stored
    under that key in Redis db 0 is decoded and passed to the YOLO processor,
    and the formatted result is written to Redis db 1 under the same key.
    """

    def __init__(self, bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
        """Create the Kafka consumer, the YOLO processor and both Redis clients."""
        self.consumer = KafkaConsumer(
            input_topic,
            bootstrap_servers=bootstrap_servers,
            group_id=group_id,
            value_deserializer=lambda x: json.loads(x.decode('utf-8')),
            auto_offset_reset='earliest'
        )
        self.input_topic = input_topic
        self.yolo_processor = YOLOv8nPoseProcessor(model_path)
        self.group_id = group_id
        self.redis_client_db0 = redis.Redis(host=redis_host, port=redis_port, db=0, password=redis_password)
        self.redis_client_db1 = redis.Redis(host=redis_host, port=redis_port, db=1, password=redis_password)
        # run() starts one thread per message and they all share the single
        # model instance above; inference is serialised because a shared YOLO
        # model is not safe to call from several threads at once.
        self._inference_lock = threading.Lock()

    def process_and_produce(self, message):
        """Process one message: load the image, run the model, store the result.

        Args:
            message: Decoded message dict. ``cache_key`` selects the image in
                Redis db 0; ``object``, ``etag`` and ``size`` are copied into
                the stored result unchanged.

        Returns:
            None. Failures are reported with ``print`` and end processing of
            this message.
        """
        cache_key = message.get('cache_key')
        object_key = message.get('object')
        etag = message.get('etag')
        size = message.get('size')

        print(f"Consumer {self.group_id} processing image with cache_key: {cache_key}")

        # Fetch the image data from Redis db0.
        img_str = self.redis_client_db0.get(cache_key)
        if img_str is None:
            print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
            return

        try:
            # Convert the base64-encoded image data into an OpenCV image.
            img_data = base64.b64decode(img_str)
            nparr = np.frombuffer(img_data, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                print(f"Error: Unable to decode image data for cache_key: {cache_key}")
                return

            with self._inference_lock:
                results = self.yolo_processor.process_image(img)
            formatted_results = self.yolo_processor.format_results(results)

            # Bundle the detections with the metadata carried by the message.
            result_data = {
                "pose_results": json.loads(formatted_results),
                'cache_key': cache_key,
                "object": object_key,
                "etag": etag,
                "size": size
            }
        except Exception as e:
            # Worker-thread boundary: report the failure and drop the message.
            print(f"Error processing image: {e}")
            return

        # Save the result to Redis db1.
        try:
            serialized_data = json.dumps(result_data)
            self.redis_client_db1.set(cache_key, serialized_data)
            print(f"Consumer {self.group_id} processed and saved results to Redis db1 with cache_key {cache_key}")
        except TypeError as e:
            print(f"Error serializing data: {e}")
            print(f"Problematic data: {result_data}")
        except Exception as e:
            print(f"Unexpected error when saving to Redis: {e}")

    def run(self):
        """Consume messages and hand each one with a ``cache_key`` to a new thread.

        NOTE(review): threads are started without an upper bound, so a burst
        of messages creates a burst of threads.
        """
        print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
        for message in self.consumer:
            message_value = message.value
            # The deserializer accepts any JSON document, so the value may be
            # a list, a scalar or null; only a dict can carry a cache_key.
            cache_key = message_value.get('cache_key') if isinstance(message_value, dict) else None
            if cache_key:
                threading.Thread(target=self.process_and_produce, args=(message_value,)).start()
            else:
                print(f"Consumer {self.group_id} error: Received message without cache_key")
|
|
|
|
def start_consumer(kafka_bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
    """Build an ImageProcessingNode from the given settings and run it.

    The call returns only when the node's ``run()`` loop returns.
    """
    worker = ImageProcessingNode(
        bootstrap_servers=kafka_bootstrap_servers,
        input_topic=input_topic,
        model_path=model_path,
        group_id=group_id,
        redis_host=redis_host,
        redis_port=redis_port,
        redis_password=redis_password,
    )
    startup_lines = (
        f"Image Processing Node {group_id} initialized.",
        f"Listening on topic: {input_topic}",
    )
    for line in startup_lines:
        print(line)
    worker.run()
|
|
|
|
if __name__ == "__main__":
    # Start several consumer threads, all configured from the loaded YAML.

    # Kafka settings
    kafka_cfg = config['kafka']
    kafka_bootstrap_servers = kafka_cfg['bootstrap_servers']
    frames_cfg = kafka_cfg['topics']['all_frames']
    input_topic = frames_cfg['name']
    num_consumers = frames_cfg['num_consumers']

    # Model path
    model_path = config['model']['pose-path']

    # Redis settings
    redis_cfg = config['redis']
    redis_host = redis_cfg['host']
    redis_port = redis_cfg['port']
    redis_password = redis_cfg['password']

    # NOTE(review): every thread gets its own group id, which in Kafka means
    # each consumer receives every message rather than sharing the load —
    # confirm this is intended.
    consumer_threads = []
    for index in range(num_consumers):
        worker_thread = threading.Thread(
            target=start_consumer,
            args=(
                kafka_bootstrap_servers,
                input_topic,
                model_path,
                f'pose_group_{index}',
                redis_host,
                redis_port,
                redis_password,
            ),
        )
        consumer_threads.append(worker_thread)
        worker_thread.start()

    # Block until every consumer thread has finished.
    for worker_thread in consumer_threads:
        worker_thread.join()
|