292 lines
11 KiB
Python
292 lines
11 KiB
Python
import os
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
from kafka import KafkaConsumer
|
|
from decord import VideoReader, cpu
|
|
from PIL import Image
|
|
import redis
|
|
from redis import Redis
|
|
import io
|
|
import re
|
|
import threading
|
|
import requests
|
|
import base64
|
|
import traceback
|
|
import json
|
|
from config import *
|
|
|
|
# Configuration.
# Every setting originates from `config` (star-imported above). A `NAME = NAME`
# line rebinds the imported value unchanged; evaluating it raises NameError at
# import time if `config` did not provide NAME.
OLLAMA_URL = OLLAMA_URL
KAFKA_BROKER = KAFKA_BROKER

# Settings specific to the "cpm_analyze" worker.
_CPM_ANALYZE_CONFIG = WORKER_CONFIGS["cpm_analyze"]
KAFKA_TOPIC = _CPM_ANALYZE_CONFIG["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_analyze_{KAFKA_GROUP_ID_PREFIX}"

REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = _CPM_ANALYZE_CONFIG["redis_db"]  # Redis DB used by this worker
MAIN_REDIS_DB = MAIN_REDIS_DB  # main Redis DB

UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR

# Make sure both directories exist.
for _directory in (UPLOAD_DIR, RESULT_DIR):
    os.makedirs(_directory, exist_ok=True)
|
|
|
|
# Initialise the Kafka consumer.
def _deserialize_task(raw):
    """Decode a Kafka message value as UTF-8 JSON.

    Returns the decoded object, or None when the value is missing or cannot be
    decoded. Raising here would abort iteration over the consumer, so a single
    malformed message would stop the worker; code reading ``message.value``
    must therefore treat None as "undecodable message".
    """
    if raw is None:
        return None
    try:
        return json.loads(raw.decode('utf-8'))
    except ValueError as exc:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueError subclasses.
        print(f"无法解析 Kafka 消息: {exc}")
        return None


consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=_deserialize_task,
)
|
|
|
|
# Initialise Redis.
# Both clients talk to the same server and differ only in the selected DB.
_REDIS_CONNECTION = {
    "host": REDIS_HOST,
    "port": REDIS_PORT,
    "password": REDIS_PASSWORD,
}

# Worker-local DB: holds the analysis results written by this worker.
redis_client = Redis(db=REDIS_DB, **_REDIS_CONNECTION)

# Main DB: holds the task status hashes.
main_redis_client = Redis(db=MAIN_REDIS_DB, **_REDIS_CONNECTION)
|
|
|
|
# Select the GPU device (currently disabled)
# torch.cuda.set_device(0)
|
|
|
|
class MediaAnalysisSystem:
    """Extract and analyse text in images and videos through an Ollama model.

    Media bytes are converted to base64-encoded PNG images and posted, with a
    fixed Chinese-language OCR prompt, to the endpoint named by the module-level
    ``OLLAMA_URL``. The streamed reply is concatenated into one answer string.
    """

    # Model name sent in every request payload.
    MODEL_NAME = "minicpm-v"

    # (connect, read) timeouts in seconds. Without a timeout a stalled server
    # would block the calling thread forever; the read timeout is generous
    # because the model may take a while before it emits its first token.
    REQUEST_TIMEOUT = (10, 600)

    # Image modes Pillow can write as PNG; anything else is converted to RGB.
    _PNG_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

    # Shared prompt; ``{media}`` is replaced by "视频" (video) or "图片" (image).
    _PROMPT_TEMPLATE = (
        "您是一个高级OCR和文本分析助手。您的主要任务是:\n"
        "1) 从该{media}中高精度提取所有文本内容,包括标准文本和数字数据,\n"
        "2) 对提取的内容进行全面分析,\n"
        "3) 将信息进行逻辑性的组织和结构化,\n"
        "4) 提供从文本/数字中发现的详细见解。\n"
        "\n"
        "对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。"
        "如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。\n"
        "请支持:\n"
        "- 多种语言(简体中文、繁体中文、英文)\n"
        "- 不同的文本方向和布局\n"
        "- 表格、图表和结构化数据格式\n"
        "- 特殊字符、符号和数学符号\n"
        "- 原始格式和文本位置信息"
    )

    def __init__(self):
        # Upper bound on the number of video frames sent to the model.
        self.MAX_NUM_FRAMES = 16

    @classmethod
    def _build_prompt(cls, media_label):
        """Return the OCR prompt with *media_label* naming the kind of media."""
        return cls._PROMPT_TEMPLATE.format(media=media_label)

    @staticmethod
    def _uniform_sample(items, n):
        """Pick *n* evenly spaced elements of *items* (centre of each of n bins)."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode *video_data* (bytes) and return sampled frames as PIL images.

        One frame is taken per ``round(average fps)`` frames, i.e. about one
        per second; if that yields more than ``MAX_NUM_FRAMES`` frames they are
        thinned out uniformly.

        Raises:
            ValueError: if the video contains no frames.
        """
        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError("Video contains no frames")

        # An average fps below 0.5 rounds to 0, which range() rejects as a
        # step; clamp to 1 so such videos are sampled frame by frame.
        step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, total_frames, step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)

        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def _query_ollama(self, prompt, encoded_images):
        """Post *prompt* and base64 *encoded_images* to Ollama; return the answer.

        Raises:
            RuntimeError: if the HTTP status is not 200 or the stream reports
                an error.
            requests.RequestException: on transport failures (logged, re-raised).
        """
        payload = {
            "model": self.MODEL_NAME,
            "prompt": prompt,
            "images": encoded_images,
        }
        try:
            # The context manager releases the streamed connection even when
            # reading the body fails part-way.
            with requests.post(
                OLLAMA_URL,
                json=payload,
                stream=True,
                timeout=self.REQUEST_TIMEOUT,
            ) as response:
                print(f"Ollama API 响应状态码: {response.status_code}")
                print(f"Ollama API 响应头: {response.headers}")
                if response.status_code != 200:
                    raise RuntimeError(f"Ollama API 错误: {response.status_code}")
                return self.process_stream_response(response)
        except requests.RequestException as e:
            print(f"请求 Ollama API 时出错: {str(e)}")
            raise

    def process_video(self, video_data, object_name):
        """Analyse a video given as bytes.

        Args:
            video_data: raw bytes of the video file.
            object_name: name used in log and error messages.

        Returns:
            dict with ``original_answer`` (model output) and ``num_frames``
            (number of frames sent to the model).

        Raises:
            ValueError: if *video_data* is empty or has no frames.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")

        frames = self.encode_video(video_data)
        encoded_frames = [self.image_to_base64(frame) for frame in frames]
        answer = self._query_ollama(self._build_prompt("视频"), encoded_frames)

        return {
            "original_answer": answer,
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Analyse an image given as bytes.

        Args:
            image_data: raw bytes of the image file.
            object_name: name used in error messages.

        Returns:
            dict with ``original_answer`` (model output) and ``timestamp``
            (local time of completion, ``YYYY-MM-DD HH:MM:SS``).

        Raises:
            ValueError: if *image_data* is empty.
        """
        if not image_data:
            raise ValueError(f"Empty image data for {object_name}")

        image = Image.open(io.BytesIO(image_data))
        encoded_image = self.image_to_base64(image)
        answer = self._query_ollama(self._build_prompt("图片"), [encoded_image])

        return {
            "original_answer": answer,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    def process_stream_response(self, response):
        """Concatenate the ``response`` fields of a newline-delimited JSON stream.

        *response* must provide ``iter_lines()``. Blank lines and lines that
        are not valid JSON objects are skipped (unparseable ones are logged).
        Reading stops at the first object whose ``done`` field is true.

        Raises:
            RuntimeError: if an object in the stream carries an ``error`` field.
        """
        chunks = []
        for line in response.iter_lines():
            if not line:
                continue
            try:
                message = json.loads(line)
            except json.JSONDecodeError:
                print(f"无法解析 JSON 行: {line}")
                continue
            if not isinstance(message, dict):
                continue
            if message.get('error'):
                # Surface server-side failures instead of returning a
                # truncated or empty answer as if it were a success.
                raise RuntimeError(f"Ollama API 错误: {message['error']}")
            if 'response' in message:
                chunks.append(message['response'])
            if message.get('done', False):
                break
        return ''.join(chunks)

    @staticmethod
    def image_to_base64(image):
        """Return *image* (a PIL image) encoded as a base64 PNG string."""
        if image.mode not in MediaAnalysisSystem._PNG_MODES:
            # e.g. CMYK JPEGs: PNG cannot store these modes directly.
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def extract_time_from_filename(object_name):
        """Derive a (start, end) time window from a ``YYYYMMDD_HHMMSS*`` name.

        The end is always 10 seconds after the start. If the base name does
        not begin with a parseable ``YYYYMMDD_HHMMSS`` stamp, the current
        local time is used as the start.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        try:
            # parts[1] is missing when the name has no underscore at all.
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)
|
|
|
|
|
|
# Module-level MediaAnalysisSystem instance used by the task loop.
media_analysis_system = MediaAnalysisSystem()
|
|
def _resolve_upload_path(filename):
    """Return the absolute path of *filename* inside UPLOAD_DIR.

    Raises ValueError when the name would resolve outside the upload
    directory (absolute paths, ``..`` components): task messages are external
    input and must not be able to make the worker read arbitrary files.
    """
    upload_root = os.path.abspath(UPLOAD_DIR)
    candidate = os.path.abspath(os.path.join(upload_root, filename))
    if os.path.commonpath([upload_root, candidate]) != upload_root:
        raise ValueError(f"Filename resolves outside the upload directory: {filename}")
    return candidate


def _mark_task_failed(task_key, error):
    """Record a failed status for *task_key*; never raises on Redis errors."""
    try:
        main_redis_client.hset(task_key, mapping={
            "status": "failed",
            "error": str(error)
        })
    except redis.RedisError as redis_error:
        # Losing the status update is preferable to killing the worker loop.
        print(f"Could not record failure for {task_key}: {redis_error}")


def process_task():
    """Consume analysis tasks from Kafka and process them one by one, forever.

    Each message value is expected to be a dict with ``task_id``, ``filename``
    (relative to UPLOAD_DIR) and ``file_type`` ("image"; anything else is
    handled as video). For every task the status hash ``task:<task_id>`` in
    the main Redis DB is set to "processing", then to "completed" or "failed".
    Results are stored in the worker Redis DB under
    ``cpm_analyze_result:<task_id>``.

    Messages that are not dicts or lack ``task_id`` are logged and skipped;
    any other per-task error is recorded on the task and the loop continues.
    """
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")

        task = message.value
        if not isinstance(task, dict) or 'task_id' not in task:
            # Without a task id there is no status hash to update.
            print(f"Skipping malformed Kafka message without task_id: {task!r}")
            continue

        task_id = task['task_id']
        task_key = f"task:{task_id}"

        try:
            filename = task['filename']
            file_type = task['file_type']
            print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")

            file_path = _resolve_upload_path(filename)

            # A leftover non-hash value under the key would make hset fail.
            key_type = main_redis_client.type(task_key)
            if key_type != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")

            is_image = file_type == "image"
            print(f"Processing {'image' if is_image else 'video'}: {filename}")
            with open(file_path, 'rb') as f:
                media_data = f.read()
            if is_image:
                result = media_analysis_system.process_image(media_data, filename)
            else:  # video
                result = media_analysis_system.process_video(media_data, filename)

            if result:
                result_key = f"cpm_analyze_result:{task_id}"
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(result),
                    "result_file": filename
                })
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "cpm_analyze",
                    "result_key": result_key
                })
                print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
            else:
                print(f"{file_type.capitalize()} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            # Top-level boundary of the worker loop: log, record, keep going.
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            print(f"错误详情: {traceback.format_exc()}")
            _mark_task_failed(task_key, e)

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
|
|
|
def listen_redis_changes():
    """Log every ``hset`` on a ``cpm_analyze_result:*`` key of the worker DB.

    Subscribes to the keyspace-notification channel of the DB that
    ``redis_client`` uses (``REDIS_DB``) and prints the full hash after each
    update. Blocks forever; intended to run in its own thread.

    NOTE(review): events only arrive if the Redis server has keyspace
    notifications enabled (``notify-keyspace-events``) — confirm the server
    configuration.
    """
    key_prefix = "cpm_analyze_result:"
    # The channel name embeds the DB index; derive it from the configured DB
    # instead of hard-coding one, so the listener follows the client.
    channel_pattern = f"__keyspace@{REDIS_DB}__:{key_prefix}*"

    pubsub = redis_client.pubsub()
    pubsub.psubscribe(channel_pattern)

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue
        if message['data'].decode('utf-8') != 'hset':
            continue

        # Everything after the key prefix is the task id, even if it
        # contains ':' itself.
        channel = message['channel'].decode('utf-8')
        task_id = channel.partition(key_prefix)[2]

        value = redis_client.hgetall(f"{key_prefix}{task_id}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {task_id}: {result}")
|
|
|
|
if __name__ == "__main__":
    print("cpm_analyze处理程序启动...")

    # Background workers: the Kafka task loop and the Redis change listener.
    task_thread = threading.Thread(target=process_task, daemon=True)
    redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)

    for worker, started_message in (
        (task_thread, "任务处理线程已启动"),
        (redis_thread, "Redis监听线程已启动"),
    ):
        worker.start()
        print(started_message)

    print("主程序进入等待状态...")
    # Both workers are daemon threads; joining them keeps the process alive.
    for worker in (task_thread, redis_thread):
        worker.join()