# Files
# api/api_history/function/cpm.py
# T
# 2025-01-12 06:15:15 +00:00
#
# 363 lines
# 14 KiB
# Python

import json
import io
from PIL import Image
import torch
from kafka import KafkaConsumer, KafkaProducer
from transformers import AutoModel, AutoTokenizer
import threading
import re
from datetime import datetime
import time
import base64
import numpy as np
import cv2
import redis
from redis import ConnectionPool
import yaml
# Load the YAML configuration (Kafka, model and Redis settings) at import time.
# NOTE(review): the path is relative, so it resolves against the process's
# current working directory -- start the worker from the directory that
# contains worker_sys/.
with open('worker_sys/function/config.yaml', mode='r') as file:
    config = yaml.safe_load(file)
class JSONEncoder(json.JSONEncoder):
    """json.JSONEncoder that additionally serialises datetime objects as ISO-8601 strings."""

    def default(self, obj):
        """Return ``obj.isoformat()`` for datetimes; defer anything else to the base encoder.

        The base implementation raises TypeError for unsupported objects.
        """
        if not isinstance(obj, datetime):
            return super().default(obj)
        return obj.isoformat()
class ImageSequenceProcessor:
    """Describe a short image sequence with a MiniCPM-V style chat model.

    The model and tokenizer are loaded from ``model_dir`` through
    ``transformers`` (``trust_remote_code=True``). The model runs on CUDA in
    bfloat16 with SDPA attention, in eval mode. All frames of a sequence are
    sent to the model in a single chat turn, and the free-text answer is mined
    for keywords by :meth:`extract_info`.
    """

    # Scene keywords, checked in this order; the first one found in the answer wins.
    ENVIRONMENTS = ("办公室", "室内", "室外", "会议室", "办公")

    # Head-count patterns, tried in order. group(1) of the first pattern that
    # matches decides ``num_people``: digits -> int, Chinese numerals via
    # NUM_WORD_TO_DIGIT, anything else (e.g. '几个', '员工', '男') -> 0.
    # NOTE(review): the fourth pattern starts with a bare ``\s*`` and therefore
    # matches any 名/位 (including inside words such as 位置); a leading token
    # looks lost in the checked-in file -- confirm the intended pattern.
    PEOPLE_PATTERNS = [re.compile(p) for p in (
        r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
        r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
        r'(男|女)(性|生|士)',
        r'(成年|未成年|青少年|老年)\s*(人|群体)',
        r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
        r'(群众|民众|大众|公众)',
        r'(男女|老少|老幼|大人|小孩)',
    )]

    # Chinese numerals (and '一个') mapped to head counts.
    # NOTE(review): in the checked-in file these keys had all become '' (an
    # apparent encoding/paste loss), collapsing the mapping to {'': 10}; they
    # are restored from the 一..十 alternation in the second pattern above.
    NUM_WORD_TO_DIGIT = {
        '一': 1, '一个': 1,
        '二': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
    }

    # Keyword vocabularies; every keyword found in the answer is reported, in list order.
    # NOTE(review): several entries in the checked-in ACTIONS/FURNITURE lists had
    # been reduced to '' (which matches every answer) and "转身" was listed
    # twice; the empty entries are omitted until the intended words are restored.
    ACTIONS = ("摔倒", "跳舞", "转身", "倒下", "躺下", "跳跃", "说话")
    INTERACTIONS = ("互动", "交流", "身体语言", "交谈", "讨论", "开会")
    OBJECTS = ("水瓶", "办公用品", "文件", "电脑")
    FURNITURE = ("椅子", "桌子", "咖啡桌", "文件柜", "沙发")

    def __init__(self, model_dir):
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
                                               attn_implementation='sdpa',
                                               torch_dtype=torch.bfloat16).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        # Maximum frame size. NOTE(review): not referenced anywhere in this file;
        # frames are currently sent to the model at their original resolution.
        self.max_size = 512

    def extract_time_from_key(self, key_name):
        """Parse the ``YYYYmmdd_HHMMSS`` timestamp at the end of ``key_name``.

        Example: ``"cam_20250112_061515.jpg"`` -> ``datetime(2025, 1, 12, 6, 15, 15)``.
        Raises ValueError (bad timestamp) or IndexError (no ``_`` separators).
        """
        parts = key_name.split('_')
        time_str = parts[-2] + '_' + parts[-1].split('.')[0]
        return datetime.strptime(time_str, "%Y%m%d_%H%M%S")

    @staticmethod
    def _frame_from_array(img, index):
        """Wrap one H x W x 3 ndarray in a PIL image; log and return None for anything else."""
        if not isinstance(img, np.ndarray):
            print(f"Unexpected data type for image {index}: {type(img)}")
            return None
        if img.ndim != 3:
            # Guarded explicitly: a 2-D array would otherwise fail on img.shape[2].
            print(f"Unexpected number of dimensions for image {index}: {img.ndim}")
            return None
        if img.shape[2] != 3:
            print(f"Unexpected number of channels for image {index}: {img.shape[2]}")
            return None
        # Image.fromarray interprets 3-channel uint8 data as RGB, so callers
        # must supply frames in RGB channel order.
        return Image.fromarray(img)

    def process_image_sequence(self, image_data, key_names):
        """Describe a sequence of frames with the model and extract keywords.

        Args:
            image_data: frames; only numpy arrays with 3 channels in the last
                axis are used (expected RGB). Other items are logged and skipped.
            key_names: one key per frame, each ending in ``_YYYYmmdd_HHMMSS.<ext>``.
                Pairs beyond the shorter of the two sequences are ignored (zip).

        Returns:
            dict with ``original_answer`` (model text), ``extracted_info``
            (see :meth:`extract_info`), ``num_frames``, ``time_range``
            (``{'start', 'end'}`` formatted ``%Y-%m-%d %H:%M``) and
            ``sequence_period_seconds`` (span between earliest and latest frame
            divided by the number of gaps; 0 for a single frame).

        Raises:
            ValueError: if no frame could be used. Errors from ``model.chat``
                propagate unchanged.
        """
        frames = []
        image_times = []
        for i, (img, key_name) in enumerate(zip(image_data, key_names)):
            try:
                frame = self._frame_from_array(img, i)
                if frame is None:
                    continue
                # Parse the timestamp before recording anything so that a bad
                # key drops the whole frame and frames/image_times stay aligned.
                frame_time = self.extract_time_from_key(key_name)
            except Exception as e:
                print(f"Error processing frame {i}: {str(e)}")
                continue
            frames.append(frame)
            image_times.append(frame_time)
            print(f"Successfully processed frame {i}")

        if not frames:
            raise ValueError("No valid frames were processed")

        start_time = min(image_times)
        end_time = max(image_times)
        time_range = {
            'start': start_time.strftime('%Y-%m-%d %H:%M'),
            'end': end_time.strftime('%Y-%m-%d %H:%M')
        }
        # Mean spacing between frames: overall span divided by the number of gaps.
        total_duration = (end_time - start_time).total_seconds()
        num_images = len(image_times)
        sequence_period_seconds = total_duration / (num_images - 1) if num_images > 1 else 0

        # The frame count reflects the frames actually sent (identical to the
        # old fixed "3 images" wording when all three frames are valid).
        question = (
            f"Analyze these {len(frames)} images as if they were frames from a video. "
            "Describe the scene in detail in Chinese, including the setting, number of people, "
            "their actions, and any changes or movements observed across the frames."
        )
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]
        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            **params
        )
        extracted_info = self.extract_info(answer)
        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "num_frames": len(frames),
            "time_range": time_range,
            "sequence_period_seconds": sequence_period_seconds
        }

    @staticmethod
    def extract_info(answer):
        """Pull coarse scene facts out of the model's Chinese description by keyword matching.

        Returns a dict with ``environment`` (first ENVIRONMENTS keyword found,
        or None), ``num_people`` (see PEOPLE_PATTERNS; None if nothing matched)
        and lists ``actions``/``interactions``/``objects``/``furniture`` holding
        every keyword of the corresponding vocabulary found in ``answer``.
        """
        cls = ImageSequenceProcessor
        lowered = answer.lower()
        info = {
            "environment": next((env for env in cls.ENVIRONMENTS if env in lowered), None),
            "num_people": None,
            "actions": [word for word in cls.ACTIONS if word in answer],
            "interactions": [word for word in cls.INTERACTIONS if word in answer],
            "objects": [word for word in cls.OBJECTS if word in answer],
            "furniture": [word for word in cls.FURNITURE if word in answer],
        }
        for pattern in cls.PEOPLE_PATTERNS:
            match = pattern.search(answer)
            if match:
                token = match.group(1)
                if token.isdigit():
                    info["num_people"] = int(token)
                else:
                    info["num_people"] = cls.NUM_WORD_TO_DIGIT.get(token, 0)
                break
        return info
class ImageSequenceAnalysisSystem:
    """Logging and timing wrapper around ImageSequenceProcessor."""

    def __init__(self, model_dir):
        # Constructing the processor loads the model and tokenizer from model_dir.
        self.image_processor = ImageSequenceProcessor(model_dir)

    def process_image_sequence(self, image_data, cache_keys):
        """Analyse one sequence of frames, printing per-frame diagnostics and timing.

        Args:
            image_data: the frames; non-ndarray items are only reported here
                (the processor decides whether to skip them).
            cache_keys: one key per frame, forwarded to the processor, which
                parses a timestamp from each.

        Returns:
            The processor's result dict, or None if any exception was raised
            (the error and traceback are printed).
        """
        print(f"Attempting to process sequence of {len(image_data)} images")
        start_time = time.time()
        try:
            print("Processing new image sequence...")
            for i, (img, cache_key) in enumerate(zip(image_data, cache_keys)):
                print(f"Image {i} type: {type(img)}")
                if isinstance(img, np.ndarray):
                    print(f"Image {i} shape: {img.shape}, dtype: {img.dtype}")
                else:
                    print(f"Unexpected data type for image {i}")
                print(f"Image {i} cache_key: {cache_key}")
            result = self.image_processor.process_image_sequence(image_data, cache_keys)
            processing_time = time.time() - start_time
            print(f"Processed image sequence for time range: {result['time_range']}")
            # sequence_period_seconds is measured in seconds; it used to be
            # printed with a "minutes" label.
            print(f"Average time between frames: {result['sequence_period_seconds']:.2f} seconds")
            print(f"Processing time: {processing_time:.2f} seconds")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            print(f"Error processing image sequence: {str(e)}")
            print(f"Processing time (including error): {processing_time:.2f} seconds")
            import traceback
            traceback.print_exc()
            return None
class ImageProcessingNode:
    """Kafka consumer that batches images into fixed-length sequences for analysis.

    Each consumed message is JSON-decoded and expected to be an object with a
    ``cache_key`` (the Redis key, in the ``redis_pool_db0`` database, of a
    base64-encoded image) plus ``etag``/``size``/``object`` metadata. Once
    SEQUENCE_LENGTH images are buffered they are analysed together, and the
    result is stored through ``redis_pool_db1`` under the first image's
    cache_key. ``self.producer`` is created but not used by this class.
    """

    # Number of consecutive images analysed as one sequence.
    SEQUENCE_LENGTH = 3

    def __init__(self, bootstrap_servers, input_topic, model_dir, group_id, redis_pool_db0, redis_pool_db1):
        self.consumer = KafkaConsumer(
            input_topic,
            bootstrap_servers=bootstrap_servers,
            group_id=group_id,
            value_deserializer=lambda x: json.loads(x.decode('utf-8')),
            auto_offset_reset='earliest'
        )
        self.producer = KafkaProducer(
            bootstrap_servers=bootstrap_servers,
            value_serializer=lambda x: json.dumps(x, cls=JSONEncoder).encode('utf-8')
        )
        self.input_topic = input_topic
        self.analysis_system = ImageSequenceAnalysisSystem(model_dir)
        self.group_id = group_id
        self.redis_pool_db0 = redis_pool_db0
        self.redis_pool_db1 = redis_pool_db1
        self.lock = threading.Lock()
        # Parallel buffers for the sequence being collected.
        self.image_buffer = []
        self.image_info_buffer = []
        self.cache_key_buffer = []

    def process_and_produce(self, message_value):
        """Buffer the image referenced by one message; analyse once a full sequence is buffered.

        Despite the name, nothing is sent to Kafka here: results go to Redis.
        Errors are printed with a traceback and swallowed so the consume loop
        keeps running.
        """
        cache_key = message_value.get('cache_key')
        try:
            img = self._load_image(cache_key)
            if img is None:
                return
            self.image_buffer.append(img)
            self.image_info_buffer.append({
                # The cache_key doubles as the key name; the processor parses
                # the frame timestamp from it.
                'key_name': cache_key,
                'etag': message_value.get('etag'),
                'size': message_value.get('size'),
                'object_key': message_value.get('object'),
            })
            self.cache_key_buffer.append(cache_key)
            # '>=' rather than '==' so an over-full buffer can never stall the node.
            if len(self.image_buffer) >= self.SEQUENCE_LENGTH:
                self._flush_sequence()
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            import traceback
            traceback.print_exc()

    def _load_image(self, cache_key):
        """Fetch the base64 image for cache_key from Redis db0 and decode it to an RGB ndarray, or return None."""
        if cache_key is None:
            print("Error: message has no cache_key")
            return None
        with redis.Redis(connection_pool=self.redis_pool_db0) as redis_client_db0:
            img_str = redis_client_db0.get(cache_key)
        if img_str is None:
            print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
            return None
        nparr = np.frombuffer(base64.b64decode(img_str), np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals undecodable data by returning None.
            print(f"Error: could not decode image data for cache_key: {cache_key}")
            return None
        # OpenCV decodes colour images in BGR order, while the processor hands
        # frames to PIL's Image.fromarray, which treats 3-channel data as RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _flush_sequence(self):
        """Analyse the buffered sequence and store the result; the buffers are always emptied."""
        images = self.image_buffer
        infos = self.image_info_buffer
        keys = self.cache_key_buffer
        # Reset first so that a failure below (analysis or Redis write) cannot
        # leave a full buffer behind and block every later sequence.
        self.image_buffer = []
        self.image_info_buffer = []
        self.cache_key_buffer = []

        result = self.analysis_system.process_image_sequence(images, keys)
        if not result:
            return
        result_data = {
            'results': result,
            'image_sequence': infos
        }
        with self.lock:
            with redis.Redis(connection_pool=self.redis_pool_db1) as redis_client_db1:
                # The first image's cache_key identifies the sequence.
                sequence_key = f"{infos[0]['key_name']}"
                redis_client_db1.set(sequence_key, json.dumps(result_data))
                print(f"Consumer {self.group_id} processed and saved results to Redis db1 with sequence_key {sequence_key}")
        print(f"Processed image sequence: {[info['key_name'] for info in infos]}")

    def run(self):
        """Consume messages forever: poll, buffer/analyse each record, then sleep 60 s."""
        print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
        while True:
            messages = self.consumer.poll(timeout_ms=1000)
            for records in messages.values():
                for record in records:
                    self.process_and_produce(record.value)
            # Report the buffer state once per loop iteration (about once a minute).
            if len(self.image_buffer) < self.SEQUENCE_LENGTH:
                print(f"Still waiting for more images. Current buffer size: {len(self.image_buffer)}")
            # NOTE(review): this sleep runs after every poll, so consumption is
            # throttled to one poll per minute -- confirm this is intended.
            time.sleep(60)
def start_consumer(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password):
    """Build the Redis pools and an ImageProcessingNode, then consume forever (blocks).

    Intended as a thread target: one call per consumer.
    """
    def make_pool(db_index):
        # Every pool shares the same server and credentials; only the DB index differs.
        return ConnectionPool(host=redis_host, port=redis_port, db=db_index, password=redis_password)

    source_pool = make_pool(0)
    # NOTE(review): although the node calls this pool "db1", results are written to Redis DB 2.
    results_pool = make_pool(2)

    node = ImageProcessingNode(
        kafka_bootstrap_servers,
        input_topic,
        model_dir,
        group_id,
        source_pool,
        results_pool,
    )
    print(f"Image Processing Node {group_id} initialized.")
    print(f"Listening on topic: {input_topic}")
    node.run()
if __name__ == "__main__":
    # Configuration is read from config.yaml, loaded when the module is imported.
    # NOTE(review): an earlier revision kept hard-coded broker/Redis hosts and a
    # plaintext Redis password here as commented-out code. They have been
    # removed; the password is still in version-control history and should be
    # rotated.
    kafka_bootstrap_servers = config['kafka']['bootstrap_servers']
    input_topic = config['kafka']['topics']['ten_seconds']['name']
    model_dir = config['model']['cpm-path']
    redis_host = config['redis']['host']
    redis_port = config['redis']['port']
    redis_password = config['redis']['password']
    num_consumers = config['kafka']['topics']['ten_seconds']['num_consumers']

    # Start one consumer thread per configured consumer.
    # NOTE(review): each thread gets a distinct group_id, so under Kafka's
    # consumer-group semantics every thread receives every message (no load
    # sharing). Each thread also loads its own copy of the model onto the GPU.
    # Confirm this is intended before setting num_consumers above 1.
    consumer_threads = []
    for i in range(num_consumers):
        group_id = f'cpm_group_{i}'
        thread = threading.Thread(
            target=start_consumer,
            args=(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password)
        )
        consumer_threads.append(thread)
        thread.start()

    # Wait for all consumer threads (each run() loops forever unless it raises).
    for thread in consumer_threads:
        thread.join()