Files
zydi-api/api/cpm_analyze.py
2025-01-12 06:15:15 +00:00

292 lines
11 KiB
Python

import os
import json
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
import requests
import base64
import traceback
import json
from config import *
# 配置
OLLAMA_URL = OLLAMA_URL
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["cpm_analyze"]["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_analyze_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["cpm_analyze"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 设置 GPU 设备
# torch.cuda.set_device(0)
class MediaAnalysisSystem:
def __init__(self):
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = """您是一个高级OCR和文本分析助手。您的主要任务是:
1) 从该视频中高精度提取所有文本内容,包括标准文本和数字数据,
2) 对提取的内容进行全面分析,
3) 将信息进行逻辑性的组织和结构化,
4) 提供从文本/数字中发现的详细见解。
对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
请支持:
- 多种语言(简体中文、繁体中文、英文)
- 不同的文本方向和布局
- 表格、图表和结构化数据格式
- 特殊字符、符号和数学符号
- 原始格式和文本位置信息"""
encoded_frames = [self.image_to_base64(frame) for frame in frames]
payload = {
"model": "minicpm-v",
"prompt": question,
"images": encoded_frames
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
print(f"Ollama API 响应状态码: {response.status_code}")
print(f"Ollama API 响应头: {response.headers}")
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
return {
"original_answer": answer,
"num_frames": len(frames),
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = """您是一个高级OCR和文本分析助手。您的主要任务是:
1) 从该图片中高精度提取所有文本内容,包括标准文本和数字数据,
2) 对提取的内容进行全面分析,
3) 将信息进行逻辑性的组织和结构化,
4) 提供从文本/数字中发现的详细见解。
对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
请支持:
- 多种语言(简体中文、繁体中文、英文)
- 不同的文本方向和布局
- 表格、图表和结构化数据格式
- 特殊字符、符号和数学符号
- 原始格式和文本位置信息"""
encoded_image = self.image_to_base64(image)
payload = {
"model": "minicpm-v",
"prompt": question,
"images": [encoded_image]
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
return {
"original_answer": answer,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
def process_stream_response(self, response):
full_response = []
for line in response.iter_lines():
if line:
try:
json_response = json.loads(line)
if 'response' in json_response:
full_response.append(json_response['response'])
if json_response.get('done', False):
break
except json.JSONDecodeError:
print(f"无法解析 JSON 行: {line}")
return ''.join(full_response)
@staticmethod
def image_to_base64(image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem()
def process_task():
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"cpm_analyze_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(task_key, mapping={
"status": "completed",
"result_type": "cpm_analyze",
"result_key": f"cpm_analyze_result:{task_id}"
})
print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
else:
print(f"{file_type.capitalize()} {filename} 处理失败")
main_redis_client.hset(task_key, "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
print(f"错误详情: {traceback.format_exc()}")
main_redis_client.hset(task_key, mapping={
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:cpm_analyze_result:*') # Listen for changes on all cpm_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"cpm_analyze_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("cpm_analyze处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()