363 lines
14 KiB
Python
363 lines
14 KiB
Python
import json
|
|
import io
|
|
from PIL import Image
|
|
import torch
|
|
from kafka import KafkaConsumer, KafkaProducer
|
|
from transformers import AutoModel, AutoTokenizer
|
|
import threading
|
|
import re
|
|
from datetime import datetime
|
|
import time
|
|
import base64
|
|
import numpy as np
|
|
import cv2
|
|
import redis
|
|
from redis import ConnectionPool
|
|
import yaml
|
|
|
|
# Load the shared YAML configuration at import time.
# NOTE: the path is relative to the process's working directory — the worker
# must be launched from the repository root for this to resolve.
# encoding is pinned to UTF-8 so non-ASCII config values parse identically
# regardless of the host locale.
with open('worker_sys/function/config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
|
|
|
|
class JSONEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``datetime`` objects as ISO-8601 strings."""

    def default(self, obj):
        # Datetimes are not JSON-serializable by default; render them as
        # ISO strings and let the base class reject everything else.
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)
|
|
|
|
class ImageSequenceProcessor:
    """Runs a multimodal chat model over a short sequence of frames and
    extracts structured fields (environment, people count, actions, objects)
    from the model's Chinese-language answer.
    """

    def __init__(self, model_dir):
        # bfloat16 on GPU; trust_remote_code is required by this checkpoint family.
        self.model = AutoModel.from_pretrained(
            model_dir, trust_remote_code=True,
            attn_implementation='sdpa', torch_dtype=torch.bfloat16
        ).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.max_size = 512  # maximum image dimension (not applied in this class)

    def extract_time_from_key(self, key_name):
        """Parse a timestamp from a key shaped '..._YYYYMMDD_HHMMSS.<ext>'.

        Raises ValueError if the key does not follow that convention.
        """
        parts = key_name.split('_')
        time_str = parts[-2] + '_' + parts[-1].split('.')[0]
        return datetime.strptime(time_str, "%Y%m%d_%H%M%S")

    def process_image_sequence(self, image_data, key_names):
        """Run the model over a frame sequence and extract structured info.

        Args:
            image_data: list of HxWx3 uint8 arrays in BGR order (as produced
                by ``cv2.imdecode`` in the consumer node).
            key_names: parallel list of keys carrying '_YYYYMMDD_HHMMSS' timestamps.

        Returns:
            dict with the raw answer, extracted info, frame count, time range
            and mean seconds between consecutive frames.

        Raises:
            ValueError: if no frame could be converted.
        """
        frames = []
        image_times = []
        for i, (img, key_name) in enumerate(zip(image_data, key_names)):
            try:
                if not isinstance(img, np.ndarray):
                    print(f"Unexpected data type for image {i}: {type(img)}")
                    continue
                # Check ndim before indexing shape[2] so a 2-D (grayscale)
                # array is reported instead of raising IndexError.
                if img.ndim != 3 or img.shape[2] != 3:
                    print(f"Unexpected image shape for image {i}: {img.shape}")
                    continue
                # cv2.imdecode yields BGR channel order; flip to RGB before
                # handing the frame to PIL / the model.
                frame = Image.fromarray(np.ascontiguousarray(img[:, :, ::-1]))
                frames.append(frame)
                image_times.append(self.extract_time_from_key(key_name))
                print(f"Successfully processed frame {i}")
            except Exception as e:
                print(f"Error processing frame {i}: {str(e)}")
                continue

        if not frames:
            raise ValueError("No valid frames were processed")

        start_time = min(image_times)
        end_time = max(image_times)
        time_range = {
            'start': start_time.strftime('%Y-%m-%d %H:%M'),
            'end': end_time.strftime('%Y-%m-%d %H:%M')
        }

        # Mean spacing between consecutive frames, in seconds (0 for one frame).
        num_images = len(image_times)
        if num_images > 1:
            sequence_period_seconds = (end_time - start_time).total_seconds() / (num_images - 1)
        else:
            sequence_period_seconds = 0

        # Prompt states the actual frame count instead of a hard-coded "3",
        # so sequences of any length produce a consistent instruction.
        question = (
            f"Analyze these {len(frames)} images as if they were frames from a video. "
            "Describe the scene in detail in Chinese, including the setting, number of "
            "people, their actions, and any changes or movements observed across the frames."
        )
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]
        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }

        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            **params
        )

        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "num_frames": len(frames),
            "time_range": time_range,
            "sequence_period_seconds": sequence_period_seconds
        }

    @staticmethod
    def extract_info(answer):
        """Extract structured fields from the model's Chinese answer using
        keyword and regex matching. Best-effort: fields that find no match
        remain None / empty lists.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        # Environment: first matching keyword wins (most specific first).
        # These are Chinese-only keywords, so no case folding is needed.
        environments = ["办公室", "室内", "室外", "会议室", "办公"]
        for env in environments:
            if env in answer:
                info["environment"] = env
                break

        # People count: try progressively looser patterns; first match wins.
        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        num_word_to_digit = {
            '二': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
        }
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if match:
                token = match.group(1)
                if token.isdigit():
                    info["num_people"] = int(token)
                elif token in ('一个', '一'):
                    info["num_people"] = 1
                else:
                    # Non-numeric token (e.g. "几名", "员工"): 0 means
                    # "people mentioned but the exact count is unknown".
                    info["num_people"] = num_word_to_digit.get(token, 0)
                break

        # Action / interaction keyword spotting (lists deduplicated — the
        # original contained "转身" twice and appended it twice).
        actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下",
                   "躺下", "跳跃", "跳", "躺", "睡", "说话"]
        info["actions"] = [action for action in actions if action in answer]

        interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
        info["interactions"] = [item for item in interactions if item in answer]

        # Object and furniture keyword spotting.
        objects = ["水瓶", "办公用品", "文件", "电脑"]
        furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
        info["objects"] = [obj for obj in objects if obj in answer]
        info["furniture"] = [item for item in furniture if item in answer]

        return info
|
|
|
|
class ImageSequenceAnalysisSystem:
    """Thin wrapper around ImageSequenceProcessor that adds input logging,
    timing and error containment around sequence processing."""

    def __init__(self, model_dir):
        self.image_processor = ImageSequenceProcessor(model_dir)

    def process_image_sequence(self, image_data, cache_keys):
        """Process one image sequence with diagnostics.

        Args:
            image_data: list of numpy image arrays.
            cache_keys: parallel list of cache keys.

        Returns:
            The processor's result dict, or None if processing failed
            (errors are logged, never propagated to the caller).
        """
        print(f"Attempting to process sequence of {len(image_data)} images")
        start_time = time.time()
        try:
            print("Processing new image sequence...")
            for i, (img, cache_key) in enumerate(zip(image_data, cache_keys)):
                print(f"Image {i} type: {type(img)}")
                if isinstance(img, np.ndarray):
                    print(f"Image {i} shape: {img.shape}, dtype: {img.dtype}")
                else:
                    print(f"Unexpected data type for image {i}")
                print(f"Image {i} cache_key: {cache_key}")

            result = self.image_processor.process_image_sequence(image_data, cache_keys)

            processing_time = time.time() - start_time

            print(f"Processed image sequence for time range: {result['time_range']}")
            # sequence_period_seconds is in seconds; the previous message
            # mislabelled the unit as "minutes".
            print(f"Average time between frames: {result['sequence_period_seconds']:.2f} seconds")
            print(f"Processing time: {processing_time:.2f} seconds")

            return result

        except Exception as e:
            processing_time = time.time() - start_time

            print(f"Error processing image sequence: {str(e)}")
            print(f"Processing time (including error): {processing_time:.2f} seconds")
            import traceback
            traceback.print_exc()
            return None
|
|
|
|
class ImageProcessingNode:
    """Kafka consumer node: pulls messages referencing cached images, fetches
    the image bytes from Redis (db0 pool), buffers them in batches of three,
    runs the analysis system, and writes the result JSON back to Redis via
    the second pool.

    NOTE(review): the three buffers are mutated outside self.lock in
    process_and_produce; this is safe only while each node instance is driven
    by exactly one thread (as run() does) — confirm before sharing instances.
    """

    def __init__(self, bootstrap_servers, input_topic, model_dir, group_id, redis_pool_db0, redis_pool_db1):
        # Consumer deserializes each record's value as UTF-8 JSON.
        self.consumer = KafkaConsumer(
            input_topic,
            bootstrap_servers=bootstrap_servers,
            group_id=group_id,
            value_deserializer=lambda x: json.loads(x.decode('utf-8')),
            auto_offset_reset='earliest'
        )
        # NOTE(review): this producer is never used anywhere in the class —
        # presumably reserved for future downstream publishing; verify before removal.
        self.producer = KafkaProducer(
            bootstrap_servers=bootstrap_servers,
            value_serializer=lambda x: json.dumps(x, cls=JSONEncoder).encode('utf-8')
        )
        self.input_topic = input_topic
        self.analysis_system = ImageSequenceAnalysisSystem(model_dir)
        self.group_id = group_id
        self.redis_pool_db0 = redis_pool_db0
        self.redis_pool_db1 = redis_pool_db1
        self.lock = threading.Lock()
        self.image_buffer = []       # decoded images awaiting a full batch of 3
        self.image_info_buffer = []  # per-image metadata (key_name/etag/size/object_key)
        self.cache_key_buffer = []   # cache_key of each buffered image

    def process_and_produce(self, message_value):
        """Fetch one image referenced by the message, buffer it, and run the
        analysis once three images have accumulated.

        The message is expected to carry 'cache_key', 'etag', 'size' and
        'object' fields (schema inferred from usage here — confirm with the
        producing service).
        """
        cache_key = message_value.get('cache_key')
        etag = message_value.get('etag')
        size = message_value.get('size')
        object_key = message_value.get('object')

        # print(f"Consumer {self.group_id} received image with key_hash: {key_hash}")

        try:
            with redis.Redis(connection_pool=self.redis_pool_db0) as redis_client_db0:
                img_str = redis_client_db0.get(cache_key)
                if img_str is None:
                    print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
                    return

            # Payload is a base64-encoded compressed image; decode it into
            # a numpy array (cv2.imdecode yields BGR channel order).
            img_data = base64.b64decode(img_str)
            nparr = np.frombuffer(img_data, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

            self.image_buffer.append(img)
            self.image_info_buffer.append({
                'key_name': cache_key,  # the cache_key doubles as key_name
                'etag': etag,
                'size': size,
                'object_key': object_key
                # 'key_hash': key_hash
            })
            self.cache_key_buffer.append(cache_key)

            # Process in fixed batches of three images.
            if len(self.image_buffer) == 3:
                result = self.analysis_system.process_image_sequence(self.image_buffer, self.cache_key_buffer)

                if result:
                    result_data = {
                        'results': result,
                        'image_sequence': self.image_info_buffer
                    }

                    with self.lock:
                        with redis.Redis(connection_pool=self.redis_pool_db1) as redis_client_db1:
                            # Use the first image's cache_key as the sequence key.
                            sequence_key = f"{self.image_info_buffer[0]['key_name']}"
                            redis_client_db1.set(sequence_key, json.dumps(result_data))
                            print(f"Consumer {self.group_id} processed and saved results to Redis db1 with sequence_key {sequence_key}")
                            print(f"Processed image sequence: {[info['key_name'] for info in self.image_info_buffer]}")

                # Clear the buffers whether or not processing produced a result.
                self.image_buffer = []
                self.image_info_buffer = []
                self.cache_key_buffer = []

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()

    def run(self):
        """Poll Kafka forever, feeding each record to process_and_produce.

        NOTE(review): when fewer than 3 images are buffered this sleeps 60s
        even if more messages are already waiting, throttling throughput —
        confirm this pacing is intentional.
        """
        print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
        while True:
            messages = self.consumer.poll(timeout_ms=1000)
            for tp, records in messages.items():
                for record in records:
                    message_value = record.value
                    self.process_and_produce(message_value)

            # Print the buffer status once per minute while waiting.
            if len(self.image_buffer) < 3:
                print(f"Still waiting for more images. Current buffer size: {len(self.image_buffer)}")
                time.sleep(60)  # wait 60 seconds (1 minute)
|
|
|
|
def start_consumer(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password):
    """Build the Redis connection pools and an ImageProcessingNode, then
    consume messages forever (this call does not return)."""
    # Pool for the image cache (db 0) and pool for results. NOTE(review):
    # the result pool targets db=2 despite the 'db1' naming used throughout
    # the node — confirm which database downstream readers expect.
    image_pool = ConnectionPool(host=redis_host, port=redis_port, db=0, password=redis_password)
    result_pool = ConnectionPool(host=redis_host, port=redis_port, db=2, password=redis_password)

    node = ImageProcessingNode(
        kafka_bootstrap_servers,
        input_topic,
        model_dir,
        group_id,
        redis_pool_db0=image_pool,
        redis_pool_db1=result_pool,
    )
    print(f"Image Processing Node {group_id} initialized.")
    print(f"Listening on topic: {input_topic}")
    node.run()
|
|
|
|
|
|
if __name__ == "__main__":
    # Pull all runtime settings from the shared YAML config.
    kafka_cfg = config['kafka']
    topic_cfg = kafka_cfg['topics']['ten_seconds']
    redis_cfg = config['redis']

    kafka_bootstrap_servers = kafka_cfg['bootstrap_servers']
    input_topic = topic_cfg['name']
    model_dir = config['model']['cpm-path']
    redis_host = redis_cfg['host']
    redis_port = redis_cfg['port']
    redis_password = redis_cfg['password']
    num_consumers = topic_cfg['num_consumers']

    # Launch one consumer thread per configured consumer, each in its own
    # Kafka consumer group.
    consumer_threads = []
    for idx in range(num_consumers):
        worker = threading.Thread(
            target=start_consumer,
            args=(kafka_bootstrap_servers, input_topic, model_dir, f'cpm_group_{idx}',
                  redis_host, redis_port, redis_password)
        )
        consumer_threads.append(worker)
        worker.start()

    # Block until every consumer thread exits (they normally run forever).
    for worker in consumer_threads:
        worker.join()
|