367 lines
14 KiB
Python
367 lines
14 KiB
Python
import os
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
from kafka import KafkaConsumer
|
|
from decord import VideoReader, cpu
|
|
from PIL import Image
|
|
import redis
|
|
from redis import Redis
|
|
import io
|
|
import re
|
|
import threading
|
|
import requests
|
|
import base64
|
|
import traceback
|
|
import json
|
|
from config import *
|
|
|
|
# Configuration
#
# OLLAMA_URL, KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, REDIS_HOST, REDIS_PORT,
# REDIS_PASSWORD, MAIN_REDIS_DB, UPLOAD_DIR, RESULT_DIR and WORKER_CONFIGS are
# already bound in this module by the wildcard import from `config`, so they
# are used directly instead of being re-assigned to themselves.
_CPM_WORKER_CONFIG = WORKER_CONFIGS["cpm"]

KAFKA_TOPIC = _CPM_WORKER_CONFIG["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_{KAFKA_GROUP_ID_PREFIX}"

REDIS_DB = _CPM_WORKER_CONFIG["redis_db"]  # Redis DB used by this worker

# Make sure the upload and result directories exist
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
|
|
|
|
def _decode_task_message(raw_value):
    """Turn the raw bytes of a Kafka message value into a Python object (UTF-8 JSON)."""
    return json.loads(raw_value.decode('utf-8'))


# Kafka consumer for this worker's topic
_consumer_options = {
    "bootstrap_servers": [KAFKA_BROKER],
    "group_id": KAFKA_GROUP_ID,
    "auto_offset_reset": 'earliest',
    "enable_auto_commit": True,
    "value_deserializer": _decode_task_message,
}
consumer = KafkaConsumer(KAFKA_TOPIC, **_consumer_options)
|
|
|
|
# Redis connections
# Both clients use the same server settings and differ only in the selected DB.
_redis_server = {
    "host": REDIS_HOST,
    "port": REDIS_PORT,
    "password": REDIS_PASSWORD,
}

# Worker DB: holds the cpm_result:* hashes
redis_client = Redis(db=REDIS_DB, **_redis_server)

# Main DB: holds the task:* status hashes
main_redis_client = Redis(db=MAIN_REDIS_DB, **_redis_server)
|
|
|
|
# GPU device selection (currently disabled)
# torch.cuda.set_device(0)
|
|
|
|
class MediaAnalysisSystem:
    """Analyse surveillance images and video clips with a vision model served by Ollama.

    Frames are PNG-encoded, base64-wrapped and posted to ``OLLAMA_URL`` together
    with a fixed Chinese analysis prompt. The streamed textual answer is returned
    verbatim and additionally scanned for a few keywords (see ``extract_info``).

    Relies on module-level names of this file: ``OLLAMA_URL``, ``requests``,
    ``VideoReader``/``cpu`` (decord) and ``Image`` (PIL).
    """

    # Model name sent with every request.
    MODEL_NAME = "minicpm-v"

    # Prompt used for video clips.
    VIDEO_PROMPT = """请对这段监控视频进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的移动模式和方向
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 视频中的时间戳分析(如果有)

请用清晰、有条理的格式描述,并突出重要发现。"""

    # Prompt used for still images.
    IMAGE_PROMPT = """请对这张监控图像进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的位置和姿态
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 图像中的时间戳信息(如果有)

请用清晰、有条理的格式描述,并突出重要发现。"""

    def __init__(self):
        # Upper bound on the number of frames sent to the model for one video.
        self.MAX_NUM_FRAMES = 16

    @staticmethod
    def _uniform_sample(items, n):
        """Return ``n`` elements of ``items``, one from the middle of each of ``n`` equal bins."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode raw video bytes into a list of PIL images.

        One frame is taken per ``round(average fps)`` frames (about one per
        second of video). If that yields more than ``MAX_NUM_FRAMES`` frames,
        they are thinned out evenly to exactly ``MAX_NUM_FRAMES``.

        Args:
            video_data: The complete video file as bytes.

        Returns:
            A list of ``PIL.Image`` frames.
        """
        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        # A reported frame rate below 0.5 would round to 0 and make range()
        # raise ValueError, so the step is clamped to at least one frame.
        step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def _ask_model(self, prompt, encoded_images, log_response=False):
        """Send ``prompt`` and base64 images to Ollama and return the full answer text.

        Args:
            prompt: Question text for the model.
            encoded_images: List of base64-encoded image strings.
            log_response: When true, print the HTTP status code and headers.

        Returns:
            The concatenated streamed answer.

        Raises:
            RuntimeError: The endpoint answered with a status other than 200.
            requests.RequestException: The request or the stream failed
                (logged, then re-raised).
        """
        payload = {
            "model": self.MODEL_NAME,
            "prompt": prompt,
            "images": encoded_images,
        }
        try:
            # The context manager closes the streamed response even when the
            # stream is abandoned early or an error is raised.
            # NOTE(review): no timeout is set, so a stalled server blocks this
            # call indefinitely — consider adding one sized for model latency.
            with requests.post(OLLAMA_URL, json=payload, stream=True) as response:
                if log_response:
                    print(f"Ollama API 响应状态码: {response.status_code}")
                    print(f"Ollama API 响应头: {response.headers}")
                if response.status_code != 200:
                    raise RuntimeError(f"Ollama API 错误: {response.status_code}")
                return self.process_stream_response(response)
        except requests.RequestException as e:
            print(f"请求 Ollama API 时出错: {str(e)}")
            raise

    def process_video(self, video_data, object_name):
        """Analyse a video clip.

        Args:
            video_data: The complete video file as bytes.
            object_name: Name of the video, used only in log and error messages.

        Returns:
            A dict with ``original_answer`` (model text), ``extracted_info``
            (see ``extract_info``) and ``num_frames`` (frames sent to the model).

        Raises:
            ValueError: ``video_data`` is empty.
            RuntimeError, requests.RequestException: See ``_ask_model``.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")

        frames = self.encode_video(video_data)
        encoded_frames = [self.image_to_base64(frame) for frame in frames]
        answer = self._ask_model(self.VIDEO_PROMPT, encoded_frames, log_response=True)

        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Analyse a single image.

        Args:
            image_data: The complete image file as bytes.
            object_name: Name of the image, used only in error messages.

        Returns:
            A dict with ``original_answer`` (model text), ``extracted_info``
            (see ``extract_info``) and ``timestamp`` (local time of completion,
            ``YYYY-MM-DD HH:MM:SS``).

        Raises:
            ValueError: ``image_data`` is empty.
            RuntimeError, requests.RequestException: See ``_ask_model``.
        """
        # Same guard as process_video, so empty uploads fail with a clear message.
        if not image_data:
            raise ValueError(f"Empty image data for {object_name}")

        image = Image.open(io.BytesIO(image_data))
        answer = self._ask_model(self.IMAGE_PROMPT, [self.image_to_base64(image)])

        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    def process_stream_response(self, response):
        """Join the ``response`` fields of a line-delimited JSON stream.

        Empty lines are skipped, lines that are not valid JSON are logged and
        skipped, and reading stops at the first object whose ``done`` is truthy.

        Args:
            response: Object exposing ``iter_lines()`` (e.g. ``requests.Response``).

        Returns:
            The concatenated text.
        """
        pieces = []
        for line in response.iter_lines():
            if not line:
                continue
            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                print(f"无法解析 JSON 行: {line}")
                continue
            if 'response' in chunk:
                pieces.append(chunk['response'])
            if chunk.get('done', False):
                break
        return ''.join(pieces)

    @staticmethod
    def image_to_base64(image):
        """Return ``image`` saved as PNG and encoded as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def extract_time_from_filename(object_name):
        """Derive a 10-second time window from a ``YYYYMMDD_HHMMSS...`` file name.

        Args:
            object_name: File name or path; only the base name is inspected.

        Returns:
            ``(start_time, end_time)`` with ``end_time`` 10 seconds after
            ``start_time``. If the name does not begin with a parsable
            ``YYYYMMDD_HHMMSS`` stamp, the window starts at the current time.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        try:
            # A name without "_" has no parts[1]; treat it like any other
            # unparsable name instead of letting IndexError escape.
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            # Read the clock once so the window is exactly 10 seconds long.
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull a few keyword-based facts out of the model's answer text.

        Args:
            answer: The model's answer.

        Returns:
            A dict with:
              * ``environment``: first matching environment keyword, else ``None``.
              * ``num_people``: head count parsed from an Arabic number or a
                single Chinese numeral (一 to 十, including 两) followed by a
                person word; ``None`` when no count is stated. Mentions of
                people without a number (e.g. "几名员工") leave it ``None``
                rather than reporting 0 people.
              * ``actions``, ``interactions``, ``objects``, ``furniture``:
                lists of the known keywords found in the text, each at most once.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": [],
        }

        # NOTE(review): the third literal was mis-encoded in the source; "室外"
        # (the counterpart of "室内") is the presumed intent — confirm.
        environments = ("办公室", "室内", "室外", "会议室")
        lowered = answer.lower()
        info["environment"] = next(
            (env for env in environments if env in lowered), None
        )

        numerals = {
            '一': 1, '一个': 1, '二': 2, '两': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
        }
        # Tried in order: Arabic digits first, then Chinese numerals, then "一个".
        # Multi-character Chinese numbers (e.g. 二十) are not understood.
        count_patterns = (
            r'(\d+)\s*(?:人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|两|三|四|五|六|七|八|九|十)\s*(?:人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个)\s*(?:人|个人|员工|用户|小朋友|成年人|女性|男性)',
        )
        for pattern in count_patterns:
            match = re.search(pattern, answer)
            if match:
                token = match.group(1)
                info["num_people"] = int(token) if token.isdigit() else numerals[token]
                break

        # Each keyword is listed once so a hit is never reported twice.
        keyword_groups = {
            "actions": ("坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下",
                        "躺下", "跳跃", "跳", "躺", "睡", "说话"),
            "interactions": ("互动", "交流", "身体语言", "交谈", "讨论", "开会"),
            "objects": ("水瓶", "办公用品", "文件", "电脑"),
            "furniture": ("椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"),
        }
        for field, keywords in keyword_groups.items():
            info[field] = [word for word in keywords if word in answer]

        return info
|
|
|
|
# Shared analyser instance used by the Kafka worker loop
media_analysis_system = MediaAnalysisSystem()


def process_task():
    """Consume analysis tasks from Kafka and publish their results to Redis.

    Runs until the consumer stops yielding messages. Each message value is
    expected to be a dict with ``task_id``, ``filename`` and ``file_type``;
    a ``file_type`` of ``"image"`` selects image analysis and anything else is
    treated as video. For every task:

      * ``task:<task_id>`` in the main Redis DB is set to ``processing``, then
        to ``completed`` (with ``result_type``/``result_key``) or ``failed``
        (with ``error``).
      * The JSON-encoded result is stored in the worker Redis DB under
        ``cpm_result:<task_id>``.

    A malformed message, a processing error, or a Redis error while recording
    a failure is logged and skipped so that one bad task does not stop the
    worker thread.
    """
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value

        try:
            task_id = task['task_id']
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # Without a task id there is no status key to update; skip the
            # message instead of letting the exception end the loop.
            print(f"忽略格式错误的Kafka消息 (offset={message.offset}): {e!r}")
            continue

        print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"

        try:
            # hset needs a hash; drop a key of any other type left under this name.
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")

            is_image = file_type == "image"
            print(f"Processing {'image' if is_image else 'video'}: {filename}")
            with open(file_path, 'rb') as f:
                media_data = f.read()
            if is_image:
                result = media_analysis_system.process_image(media_data, filename)
            else:
                result = media_analysis_system.process_video(media_data, filename)

            if result:
                result_key = f"cpm_result:{task_id}"
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(result),
                    "result_file": filename,
                })
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "cpm",
                    "result_key": result_key,
                })
                print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
            else:
                print(f"{file_type.capitalize()} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            # Boundary of the worker loop: log everything and move on.
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            print(f"错误详情: {traceback.format_exc()}")
            try:
                main_redis_client.hset(task_key, mapping={
                    "status": "failed",
                    "error": str(e),
                })
            except redis.RedisError as redis_error:
                # Redis itself may be the cause of the failure; do not let the
                # status update take the worker thread down with it.
                print(f"无法将任务 {task_id} 标记为失败: {redis_error}")

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
|
|
|
def listen_redis_changes():
    """Log every ``hset`` performed on a ``cpm_result:*`` key of the worker Redis DB.

    Subscribes to the keyspace-notification channel of the DB that
    ``redis_client`` uses (``REDIS_DB``) and, for each ``hset`` event, reads the
    hash back and prints its decoded contents. Blocks for as long as the
    subscription delivers messages.

    NOTE(review): Redis only publishes these events when keyspace
    notifications are enabled on the server (``notify-keyspace-events``);
    verify the deployment configures that.
    """
    key_prefix = "cpm_result:"
    pubsub = redis_client.pubsub()
    # Use the configured DB index rather than a fixed number, so the listener
    # keeps watching the DB that results are actually written to.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:{key_prefix}*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue
        if message['data'].decode('utf-8') != 'hset':
            continue

        # Channel format: "__keyspace@<db>__:<key>". Splitting once on "__:"
        # keeps task ids that themselves contain ":" intact.
        channel = message['channel'].decode('utf-8')
        redis_key = channel.partition('__:')[2]
        task_id = redis_key[len(key_prefix):]

        value = redis_client.hgetall(redis_key)
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {task_id}: {result}")
|
|
|
|
if __name__ == "__main__":
    print("cpm处理程序启动...")

    # Background workers in launch order: (thread target, message printed once started)
    launch_plan = (
        (process_task, "任务处理线程已启动"),
        (listen_redis_changes, "Redis监听线程已启动"),
    )

    workers = []
    for target, started_message in launch_plan:
        worker = threading.Thread(target=target, daemon=True)
        worker.start()
        print(started_message)
        workers.append(worker)

    print("主程序进入等待状态...")
    # Daemon threads die with the process, so the main thread waits on them.
    for worker in workers:
        worker.join()