Files
2025-01-12 06:15:15 +00:00

367 lines
14 KiB
Python

import os
import json
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
import requests
import base64
import traceback
import json
from config import *
# 配置
OLLAMA_URL = OLLAMA_URL
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["cpm"]["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["cpm"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 设置 GPU 设备
# torch.cuda.set_device(0)
class MediaAnalysisSystem:
def __init__(self):
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = """请对这段监控视频进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的移动模式和方向
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 视频中的时间戳分析(如果有)
请用清晰、有条理的格式描述,并突出重要发现。"""
encoded_frames = [self.image_to_base64(frame) for frame in frames]
payload = {
"model": "minicpm-v",
"prompt": question,
"images": encoded_frames
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
print(f"Ollama API 响应状态码: {response.status_code}")
print(f"Ollama API 响应头: {response.headers}")
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = """请对这张监控图像进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的位置和姿态
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 图像中的时间戳信息(如果有)
请用清晰、有条理的格式描述,并突出重要发现。"""
encoded_image = self.image_to_base64(image)
payload = {
"model": "minicpm-v",
"prompt": question,
"images": [encoded_image]
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
def process_stream_response(self, response):
full_response = []
for line in response.iter_lines():
if line:
try:
json_response = json.loads(line)
if 'response' in json_response:
full_response.append(json_response['response'])
if json_response.get('done', False):
break
except json.JSONDecodeError:
print(f"无法解析 JSON 行: {line}")
return ''.join(full_response)
@staticmethod
def image_to_base64(image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem()
def process_task():
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"cpm_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(task_key, mapping={
"status": "completed",
"result_type": "cpm",
"result_key": f"cpm_result:{task_id}"
})
print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
else:
print(f"{file_type.capitalize()} {filename} 处理失败")
main_redis_client.hset(task_key, "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
print(f"错误详情: {traceback.format_exc()}")
main_redis_client.hset(task_key, mapping={
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:cpm_result:*') # Listen for changes on all cpm_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"cpm_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("cpm处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()