Initial commit

This commit is contained in:
2025-01-12 06:15:15 +00:00
commit a3dcc7a619
131 changed files with 19998 additions and 0 deletions
+20
View File
@@ -0,0 +1,20 @@
api_history/OpenBMB/*
!api_history/OpenBMB/.gitkeep
chat_history/ChatTTS/*
!chat_history/ChatTTS/.gitkeep
api_chat/OpenBMB/*
!api_chat/OpenBMB/.gitkeep
api_chat/GPT_SoVITS/*
!api_chat/GPT_SoVITS/.gitkeep
api_chat/tools/*
!api_chat/tools/.gitkeep
api_chat/runtime/*
!api_chat/runtime/.gitkeep
api_chat/docs/*
!api_chat/docs/.gitkeep
+318
View File
@@ -0,0 +1,318 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
from config import *
import insightface
from insightface.app import FaceAnalysis
from insightface.utils import face_align
torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 配置
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["compare"]["kafka_topic"]
KAFKA_GROUP_ID = f"compare_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["compare"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class FaceComparator:
def __init__(self):
print("初始化 InsightFace...")
self.app = FaceAnalysis(name='buffalo_l', allowed_modules=['detection', 'recognition', 'genderage'])
self.app.prepare(ctx_id=0, det_size=(640, 640))
print(f"InsightFace 初始化完成,使用设备: {device}")
def detect(self, frame):
"""检测人脸并返回所有特征"""
try:
faces = self.app.get(frame)
return faces
except Exception as e:
return []
def format_results(self, faces):
"""格式化检测结果"""
print("开始格式化检测结果...")
try:
formatted_results = []
for i, face in enumerate(faces):
# 设置默认值,避免None导致的错误
result = {
'bbox': face.bbox.tolist(),
'kps': face.kps.tolist(),
'gender': int(face.gender) if hasattr(face, 'gender') and face.gender is not None else -1,
'age': float(face.age) if hasattr(face, 'age') and face.age is not None else 0.0,
'det_score': float(face.det_score),
'embedding': face.embedding.tolist()
}
formatted_results.append(result)
return formatted_results
except Exception as e:
print(f"格式化结果时出错: {str(e)}")
print(f"错误类型: {type(e)}")
import traceback
print(f"错误堆栈: {traceback.format_exc()}")
return []
def draw_results(self, frame, results):
"""在图像上绘制检测结果"""
print("开始绘制检测结果...")
try:
annotated_frame = frame.copy()
if not results:
print("没有检测结果,返回原始图像")
return annotated_frame
for i, r in enumerate(results):
bbox = r['bbox']
kps = r['kps']
cv2.rectangle(annotated_frame,
(int(bbox[0]), int(bbox[1])),
(int(bbox[2]), int(bbox[3])),
(0, 255, 0), 2)
for x, y in kps:
cv2.circle(annotated_frame,
(int(x), int(y)),
3, (255, 255, 0), -1)
gender_text = 'Male' if r['gender'] == 1 else 'Female' if r['gender'] == 0 else 'Unknown'
label = f"Age: {int(r['age'])} Gender: {gender_text}"
cv2.putText(annotated_frame,
label,
(int(bbox[0]), int(bbox[1]-10)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(0, 255, 0),
2)
return annotated_frame
except Exception as e:
print(f"绘制结果时出错: {str(e)}")
return frame.copy()
detector = FaceComparator()
def process_image(image_path):
try:
original_img = cv2.imread(image_path)
if original_img is None:
raise ValueError(f"无法读取图像文件: {image_path}")
if original_img.size == 0:
raise ValueError("图像数据为空")
faces = detector.detect(original_img)
if not faces:
json_results = []
else:
json_results = detector.format_results(faces)
if not json_results: # 如果格式化失败,使用空列表
json_results = []
annotated_img = detector.draw_results(original_img, json_results)
# 即使没有检测到人脸或格式化失败,也返回处理结果
return json_results, annotated_img
except Exception as e:
return None, None
def process_video(video_path):
try:
cap = cv2.VideoCapture(video_path)
frame_count = 0
json_results = []
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_shape = (height, width)
out = cv2.VideoWriter(video_path.replace(UPLOAD_DIR, RESULT_DIR), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % fps == 0:
preprocessed_frame = preprocess_frame(frame)
results = detector.detect(preprocessed_frame)
frame_json_results = detector.format_results(results, original_shape)
json_results.append({"frame": frame_count, "detections": frame_json_results})
annotated_frame = detector.draw_results(frame, frame_json_results)
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
return json_results
except Exception as e:
print(f"处理视频时出错: {str(e)}")
return None
def preprocess_frame(frame):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_resized = cv2.resize(frame_rgb, (640, 640))
frame_transposed = frame_resized.transpose((2, 0, 1))
frame_contiguous = np.ascontiguousarray(frame_transposed)
frame_tensor = torch.from_numpy(frame_contiguous).float()
frame_normalized = frame_tensor.to(device) / 255.0
frame_batched = frame_normalized.unsqueeze(0)
return frame_batched
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
json_results, annotated_img = process_image(file_path)
if json_results and annotated_img is not None:
result_filename = f"compare_{filename}"
result_path = os.path.join(RESULT_DIR, result_filename)
cv2.imwrite(result_path, annotated_img)
redis_client.hset(f"compare_result:{task_id}",
"result", json.dumps(json_results))
redis_client.hset(f"compare_result:{task_id}",
"result_file", result_filename)
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "compare")
main_redis_client.hset(f"task:{task_id}", "result_key", f"compare_result:{task_id}")
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
json_results = process_video(file_path)
if json_results:
result_filename = f"compare_{filename}"
redis_client.hset(f"compare_result:{task_id}",
"result", json.dumps(json_results))
redis_client.hset(f"compare_result:{task_id}",
"result_file", result_filename)
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "compare")
main_redis_client.hset(f"task:{task_id}", "result_key", f"compare_result:{task_id}")
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
main_redis_client.hset(f"task:{task_id}", "error", str(e))
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:compare_result:*') # 监听所有compare_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hmset':
value = redis_client.hgetall(f"compare_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("compare处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+81
View File
@@ -0,0 +1,81 @@
# config.py
import os
# Kafka配置
KAFKA_BROKER = "222.186.10.253:9092"
KAFKA_GROUP_ID_PREFIX = "group"
# Redis配置
REDIS_HOST = "150.158.144.159"
REDIS_PORT = 13003
REDIS_PASSWORD = "Obscura@2024"
MAIN_REDIS_DB = 0
REDIS_API_DB = 2
REDIS_API_USAGE_DB = 3
# 目录配置
UPLOAD_DIR = "/obscura/task/upload"
RESULT_DIR = "/obscura/task/result"
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 模型配置
YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt"
POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt"
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt"
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
# COMPARE_MODEL_PATH = "/obscura/models/insightface/insw_r100_glint360k.onnx"
# Ollama配置
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
# 各个worker的配置
WORKER_CONFIGS = {
"yolo": {
"kafka_topic": "yolo",
"redis_db": 4,
},
"pose": {
"kafka_topic": "pose",
"redis_db": 5,
},
"qwenvl": {
"kafka_topic": "qwenvl",
"redis_db": 9,
},
"qwenvl_analyze": {
"kafka_topic": "qwenvl_analyze",
"redis_db": 32,
},
"cpm": {
"kafka_topic": "cpm",
"redis_db": 8,
},
"cpm_analyze": {
"kafka_topic": "cpm_analyze",
"redis_db": 31,
},
"fall": {
"kafka_topic": "fall",
"redis_db": 6,
},
"face": {
"kafka_topic": "face",
"redis_db": 7,
},
"mediapipe": {
"kafka_topic": "mediapipe",
"redis_db": 10,
},
"compare": {
"kafka_topic": "compare",
"redis_db": 30,
}
}
# GPU设置
CUDA_DEVICE_0 = "cuda:0"
CUDA_DEVICE_1 = "cuda:1"
+292
View File
@@ -0,0 +1,292 @@
import os
import json
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
import requests
import base64
import traceback
import json
from config import *
# 配置
OLLAMA_URL = OLLAMA_URL
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["cpm_analyze"]["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_analyze_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["cpm_analyze"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 设置 GPU 设备
# torch.cuda.set_device(0)
class MediaAnalysisSystem:
def __init__(self):
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = """您是一个高级OCR和文本分析助手。您的主要任务是:
1) 从该视频中高精度提取所有文本内容,包括标准文本和数字数据,
2) 对提取的内容进行全面分析,
3) 将信息进行逻辑性的组织和结构化,
4) 提供从文本/数字中发现的详细见解。
对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
请支持:
- 多种语言(简体中文、繁体中文、英文)
- 不同的文本方向和布局
- 表格、图表和结构化数据格式
- 特殊字符、符号和数学符号
- 原始格式和文本位置信息"""
encoded_frames = [self.image_to_base64(frame) for frame in frames]
payload = {
"model": "minicpm-v",
"prompt": question,
"images": encoded_frames
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
print(f"Ollama API 响应状态码: {response.status_code}")
print(f"Ollama API 响应头: {response.headers}")
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
return {
"original_answer": answer,
"num_frames": len(frames),
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = """您是一个高级OCR和文本分析助手。您的主要任务是:
1) 从该图片中高精度提取所有文本内容,包括标准文本和数字数据,
2) 对提取的内容进行全面分析,
3) 将信息进行逻辑性的组织和结构化,
4) 提供从文本/数字中发现的详细见解。
对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
请支持:
- 多种语言(简体中文、繁体中文、英文)
- 不同的文本方向和布局
- 表格、图表和结构化数据格式
- 特殊字符、符号和数学符号
- 原始格式和文本位置信息"""
encoded_image = self.image_to_base64(image)
payload = {
"model": "minicpm-v",
"prompt": question,
"images": [encoded_image]
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
return {
"original_answer": answer,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
def process_stream_response(self, response):
full_response = []
for line in response.iter_lines():
if line:
try:
json_response = json.loads(line)
if 'response' in json_response:
full_response.append(json_response['response'])
if json_response.get('done', False):
break
except json.JSONDecodeError:
print(f"无法解析 JSON 行: {line}")
return ''.join(full_response)
@staticmethod
def image_to_base64(image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem()
def process_task():
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"cpm_analyze_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(task_key, mapping={
"status": "completed",
"result_type": "cpm_analyze",
"result_key": f"cpm_analyze_result:{task_id}"
})
print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
else:
print(f"{file_type.capitalize()} {filename} 处理失败")
main_redis_client.hset(task_key, "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
print(f"错误详情: {traceback.format_exc()}")
main_redis_client.hset(task_key, mapping={
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:cpm_analyze_result:*') # Listen for changes on all cpm_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"cpm_analyze_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("cpm_analyze处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()
+367
View File
@@ -0,0 +1,367 @@
import os
import json
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
import requests
import base64
import traceback
import json
from config import *
# 配置
OLLAMA_URL = OLLAMA_URL
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["cpm"]["kafka_topic"]
KAFKA_GROUP_ID = f"cpm_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["cpm"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 设置 GPU 设备
# torch.cuda.set_device(0)
class MediaAnalysisSystem:
def __init__(self):
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = """请对这段监控视频进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的移动模式和方向
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 视频中的时间戳分析(如果有)
请用清晰、有条理的格式描述,并突出重要发现。"""
encoded_frames = [self.image_to_base64(frame) for frame in frames]
payload = {
"model": "minicpm-v",
"prompt": question,
"images": encoded_frames
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
print(f"Ollama API 响应状态码: {response.status_code}")
print(f"Ollama API 响应头: {response.headers}")
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = """请对这张监控图像进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的位置和姿态
10. 携带的物品或物体
11. 群体动态和聚集情况
12. 图像中的时间戳信息(如果有)
请用清晰、有条理的格式描述,并突出重要发现。"""
encoded_image = self.image_to_base64(image)
payload = {
"model": "minicpm-v",
"prompt": question,
"images": [encoded_image]
}
try:
response = requests.post(OLLAMA_URL, json=payload, stream=True)
if response.status_code == 200:
answer = self.process_stream_response(response)
else:
raise Exception(f"Ollama API 错误: {response.status_code}")
except requests.RequestException as e:
print(f"请求 Ollama API 时出错: {str(e)}")
raise
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
def process_stream_response(self, response):
full_response = []
for line in response.iter_lines():
if line:
try:
json_response = json.loads(line)
if 'response' in json_response:
full_response.append(json_response['response'])
if json_response.get('done', False):
break
except json.JSONDecodeError:
print(f"无法解析 JSON 行: {line}")
return ''.join(full_response)
@staticmethod
def image_to_base64(image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem()
def process_task():
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"cpm_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(task_key, mapping={
"status": "completed",
"result_type": "cpm",
"result_key": f"cpm_result:{task_id}"
})
print(f"{file_type.capitalize()} {filename} 已处理,结果已保存")
else:
print(f"{file_type.capitalize()} {filename} 处理失败")
main_redis_client.hset(task_key, "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
print(f"错误详情: {traceback.format_exc()}")
main_redis_client.hset(task_key, mapping={
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:cpm_result:*') # Listen for changes on all cpm_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"cpm_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("cpm处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()
+290
View File
@@ -0,0 +1,290 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
from config import *
torch.cuda.set_device(1)
# 配置
MODEL_PATH = FACE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["face"]["kafka_topic"]
KAFKA_GROUP_ID = f"face_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["face"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class faceDetector:
def __init__(self, model_path):
self.model = YOLO(model_path).to('cuda:1')
def detect(self, frame):
results = self.model(frame, device='cuda:1')
return results
def format_results(self, results, original_shape):
formatted_results = []
for r in results:
boxes = r.boxes
keypoints = r.keypoints
for i in range(len(boxes)):
box = boxes[i]
kpts = keypoints[i]
# 调整边界框坐标以适应原始图像大小
orig_h, orig_w = original_shape[:2]
model_h, model_w = r.orig_shape
scale_x, scale_y = orig_w / model_w, orig_h / model_h
bbox = box.xyxy[0].cpu().numpy()
bbox_scaled = [
bbox[0] * scale_x, bbox[1] * scale_y,
bbox[2] * scale_x, bbox[3] * scale_y
]
# 调整关键点坐标以适应原始图像大小
kpts_scaled = kpts.xy[0].cpu().numpy() * np.array([scale_x, scale_y])
formatted_results.append({
"bbox": bbox_scaled,
"confidence": box.conf.item(),
"keypoints": kpts_scaled.tolist()
})
return formatted_results
def draw_results(self, frame, results):
annotated_frame = frame.copy()
for r in results:
bbox = r["bbox"]
keypoints = r["keypoints"]
# 绘制边界框
cv2.rectangle(annotated_frame,
(int(bbox[0]), int(bbox[1])),
(int(bbox[2]), int(bbox[3])),
(0, 255, 0), 2)
# 绘制关键点
for kp in keypoints:
cv2.circle(annotated_frame,
(int(kp[0]), int(kp[1])),
5, (255, 0, 0), -1)
return annotated_frame
detector = faceDetector(MODEL_PATH)
def process_image(image_path):
try:
original_img = cv2.imread(image_path)
original_shape = original_img.shape
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))
img = img.transpose((2, 0, 1))
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).float()
img /= 255.0
img = img.unsqueeze(0)
results = detector.detect(img)
json_results = detector.format_results(results, original_shape)
annotated_img = detector.draw_results(original_img, json_results)
return json_results, annotated_img
except Exception as e:
print(f"处理图像时出错: {str(e)}")
return None, None
def process_video(video_path):
try:
cap = cv2.VideoCapture(video_path)
frame_count = 0
json_results = []
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_shape = (height, width)
out = cv2.VideoWriter(video_path.replace(UPLOAD_DIR, RESULT_DIR), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % fps == 0:
preprocessed_frame = preprocess_frame(frame)
results = detector.detect(preprocessed_frame)
frame_json_results = detector.format_results(results, original_shape)
json_results.append({"frame": frame_count, "detections": frame_json_results})
annotated_frame = detector.draw_results(frame, frame_json_results)
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
return json_results
except Exception as e:
print(f"处理视频时出错: {str(e)}")
return None
def preprocess_frame(frame):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_resized = cv2.resize(frame_rgb, (640, 640))
frame_transposed = frame_resized.transpose((2, 0, 1))
frame_contiguous = np.ascontiguousarray(frame_transposed)
frame_tensor = torch.from_numpy(frame_contiguous).float()
frame_normalized = frame_tensor / 255.0
frame_batched = frame_normalized.unsqueeze(0)
return frame_batched
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
json_results, annotated_img = process_image(file_path)
if json_results and annotated_img is not None:
result_filename = f"face_{filename}"
result_path = os.path.join(RESULT_DIR, result_filename)
cv2.imwrite(result_path, annotated_img)
redis_client.hmset(f"face_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "face",
"result_key": f"face_result:{task_id}"
})
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
json_results = process_video(file_path)
if json_results:
result_filename = f"face_{filename}"
redis_client.hmset(f"face_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "face",
"result_key": f"face_result:{task_id}"
})
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hmset(f"task:{task_id}", {
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hmset':
value = redis_client.hgetall(f"face_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("face处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+274
View File
@@ -0,0 +1,274 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
from config import *
# 配置
MODEL_PATH = FALL_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["fall"]["kafka_topic"]
KAFKA_GROUP_ID = f"fall_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["fall"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='latest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class fallDetector:
def __init__(self, model_path):
self.model = YOLO(model_path)
def detect(self, frame):
results = self.model(frame)
return results
def format_results(self, results):
formatted_results = []
for r in results:
if not hasattr(r, 'boxes') or len(r.boxes) == 0:
print("没有检测到任何对象")
return [{"message": "No objects detected"}]
boxes = r.boxes
names = getattr(r, 'names', {})
for i in range(len(boxes)):
box = boxes[i]
if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'):
print(f"警告: 第 {i} 个框缺少必要的属性")
continue
try:
class_id = int(box.cls.item())
formatted_result = {
"bbox": box.xyxy.tolist()[0],
"confidence": box.conf.item(),
"class_id": class_id,
"class": names.get(class_id, f"Unknown-{class_id}")
}
formatted_results.append(formatted_result)
except Exception as e:
print(f"处理第 {i} 个框时出错: {str(e)}")
# print("格式化后的结果:", formatted_results)
return formatted_results
def draw_results(self, frame, results):
for r in results:
annotated_frame = r.plot()
return annotated_frame
detector = fallDetector(MODEL_PATH)
def process_image(image_data, filename):
try:
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
results = detector.detect(img)
# Format results for JSON
json_results = detector.format_results(results)
# Draw results on image
annotated_img = detector.draw_results(img, results)
# Save annotated image
annotated_filename = f"fall_{filename}"
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
cv2.imwrite(annotated_path, annotated_img)
return json_results, annotated_filename
except Exception as e:
print(f"Error processing image: {str(e)}")
return None, None
def process_video(video_data, filename):
try:
temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
with open(temp_video_path, 'wb') as temp_video:
temp_video.write(video_data)
cap = cv2.VideoCapture(temp_video_path)
frame_count = 0
json_results = []
# Get video properties
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Create output video file
annotated_filename = f"fall_{filename}"
output_path = os.path.join(RESULT_DIR, annotated_filename)
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = detector.detect(frame)
frame_json_results = detector.format_results(results)
json_results.append({"frame": frame_count, "detections": frame_json_results})
annotated_frame = detector.draw_results(frame, results)
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
# Clean up temporary input video file
os.remove(temp_video_path)
return json_results, annotated_filename
except Exception as e:
print(f"Error processing video: {str(e)}")
return None, None
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(f"task:{task_id}", "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
json_results, annotated_filename = process_image(image_data, filename)
if json_results and annotated_filename is not None:
result_path = os.path.join(RESULT_DIR, annotated_filename)
redis_client.hset(f"fall_result:{task_id}", mapping={
"result": json.dumps(json_results),
"result_file": annotated_filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "fall")
main_redis_client.hset(f"task:{task_id}", "result_key", f"fall_result:{task_id}")
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
json_results, annotated_filename = process_video(video_data, filename)
if json_results and annotated_filename:
redis_client.hset(f"fall_result:{task_id}", mapping={
"result": json.dumps(json_results),
"result_file": annotated_filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "fall")
main_redis_client.hset(f"task:{task_id}", "result_key", f"fall_result:{task_id}")
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hset(f"task:{task_id}", {
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"fall_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("fall处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+426
View File
@@ -0,0 +1,426 @@
# main.py
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timedelta, timezone
import string
from decord import VideoReader
from PIL import Image
from fastapi.responses import FileResponse
import logging
from config import *
app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)
# CORS设置
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = KAFKA_BROKER
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = MAIN_REDIS_DB
REDIS_API_DB = REDIS_API_DB
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
MAX_FILE_AGE = timedelta(hours=1)
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 定义支持的任务类型
KAFKA_TOPICS = {
'pose': 'pose',
'mediapipe': 'mediapipe',
'qwenvl': 'qwenvl',
'yolo': 'yolo',
'fall': 'fall',
'face': 'face',
'cpm': 'cpm',
'compare': 'compare'
}
TASK_TYPES = list(KAFKA_TOPICS.keys())
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
redis_api_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_DB
)
redis_api_usage_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_USAGE_DB
)
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db'])
@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
if os.path.isfile(file_path):
return FileResponse(file_path)
else:
return {"message": "Favicon not found"}, 404
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# 定义 base62 字符集
BASE62 = string.digits + string.ascii_letters
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
logging.info(f"验证API密钥: {api_key}")
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
logging.warning(f"API密钥不存在: {api_key}")
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
logging.warning(f"API密钥已停用: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
logging.warning(f"API密钥已过期: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
logging.info(f"API密钥验证成功: {api_key}")
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
# 更新总的token使用量
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
# 更新特定模型的token使用量
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
def calculate_tokens(file_path: str, file_type: str) -> int:
base_tokens = 0
try:
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
# 基础token:每MB文件大小消耗10个token
base_tokens = int((file_size / (1024 * 1024)) * 0.1)
if file_type == "image":
img = Image.open(file_path)
width, height = img.size
pixel_count = width * height
# 图片token:每100个像素额外消耗5个token
image_tokens = int((pixel_count / 100000000) * 0.1)
base_tokens += image_tokens
elif file_type == "video":
vr = VideoReader(file_path)
fps = vr.get_avg_fps()
frame_count = len(vr)
width, height = vr[0].shape[1], vr[0].shape[0]
pixel_count = width * height * frame_count
duration = frame_count / fps # 视频时长(秒)
# 视频token:每100万像素每秒额外消耗1个token
video_tokens = int((pixel_count / 100000000) * (duration / 60))
base_tokens += video_tokens
return max(1, base_tokens) # 确保至少返回1个token
except Exception as e:
print(f"计算token时出错: {str(e)}")
return 1 # 出错时返回默认值1
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
if task_type not in KAFKA_TOPICS:
raise HTTPException(status_code=400, detail="不支持的任务类型")
content = await file.read()
file_extension = os.path.splitext(file.filename)[1].lower()
new_filename = f"{uuid.uuid4()}{file_extension}"
file_path = os.path.join(UPLOAD_DIR, new_filename)
with open(file_path, "wb") as f:
f.write(content)
# 计算 token
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
tokens_required = calculate_tokens(file_path, file_type)
# 检查并更新 token 使用量
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, task_type)
# 创建任务记录
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"filename": new_filename,
"file_type": file_type,
"task_type": task_type,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 存储任务信息到Redis
redis_client.set(f"task:{task_id}", json.dumps(task_data))
logging.info(f"任务信息已存储到Redis: {task_id}")
# 发送任务到对应的Kafka主题
kafka_topic = KAFKA_TOPICS[task_type]
producer.send(kafka_topic, task_data)
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
response_data = {
"message": "文件已上传并排队等待处理",
"task_id": task_id,
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{task_type}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
logging.info(f"上传文件完成: {task_id}")
return JSONResponse(content=response_data)
# 为每个任务类型创建单独的端点
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
logging.info(f"收到 /pose端点的请求")
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
logging.info(f"收到 /face 端点的请求")
return await upload_file(file, task_type="face", api_key_info=api_key_info)
@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
@v1_app.post("/compare")
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="compare", api_key_info=api_key_info)
@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务状态
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
return {"status": task_info['status'], "message": "Task is not completed yet"}
result_type = task_info['result_type']
result_key = task_info['result_key']
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'cpm': redis_cpm_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client,
'qwenvl': redis_qwenvl_client,
'compare': redis_compare_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
# 将 result 字段解析为 JSON(如果存在)
if 'result' in result:
result['result'] = json.loads(result['result'])
return {
"status": "completed",
"result_type": result_type,
"result": result
}
@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务信息
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
raise HTTPException(status_code=400, detail="Task is not completed yet")
result_type = task_info.get('result_type')
result_key = task_info.get('result_key')
if not result_key:
raise HTTPException(status_code=404, detail="Result key not found")
if result_type in ['cpm', 'qwenvl']:
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client,
'compare': redis_compare_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
result_file = result.get('result_file')
if not result_file:
raise HTTPException(status_code=404, detail="Result file not found")
file_path = os.path.join(RESULT_DIR, result_file)
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="Result image file not found")
return FileResponse(file_path, media_type="image/png")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8005)
+299
View File
@@ -0,0 +1,299 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from config import *
# 配置
MODEL_PATH = MEDIAPIPE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["mediapipe"]["kafka_topic"]
KAFKA_GROUP_ID = f"mediapipe_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["mediapipe"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# Ensure directories exist
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# Initialize Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class mediapipeEmbedder:
def __init__(self, model_path):
base_options = python.BaseOptions(model_asset_path=model_path)
options = vision.FaceLandmarkerOptions(
base_options=base_options,
output_face_blendshapes=True,
output_facial_transformation_matrixes=True,
num_faces=1
)
self.detector = vision.FaceLandmarker.create_from_options(options)
def get_mediapipe_landmarks(self, image):
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
detection_result = self.detector.detect(mp_image)
if detection_result.face_landmarks:
return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
return None
def process_image(self, image_data):
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
landmarks = self.get_mediapipe_landmarks(img)
if landmarks is not None:
# Calculate a more detailed mediapipe embedding
embedding = self.calculate_detailed_embedding(landmarks)
# Draw landmarks on the image
for lm in landmarks:
cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)
return {
"embedding": embedding,
"landmarks": landmarks.tolist()
}, img
else:
return None, img
def calculate_detailed_embedding(self, landmarks):
# Calculate various statistical features
mean = np.mean(landmarks, axis=0)
std = np.std(landmarks, axis=0)
median = np.median(landmarks, axis=0)
min_vals = np.min(landmarks, axis=0)
max_vals = np.max(landmarks, axis=0)
# Calculate pairwise distances between key facial landmarks
nose_tip = landmarks[4]
left_eye = landmarks[159]
right_eye = landmarks[386]
left_mouth = landmarks[61]
right_mouth = landmarks[291]
eye_distance = np.linalg.norm(left_eye - right_eye)
mouth_width = np.linalg.norm(left_mouth - right_mouth)
nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)
# Calculate face shape features
face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])
# Combine all features into a single embedding
embedding = np.concatenate([
mean, std, median, min_vals, max_vals,
[eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
])
return embedding.tolist()
embedder = mediapipeEmbedder(MODEL_PATH)
def process_image(image_data, filename):
try:
results, annotated_img = embedder.process_image(image_data)
if results:
# Save annotated image
annotated_filename = f"mediapipe_{filename}"
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
cv2.imwrite(annotated_path, annotated_img)
return results, annotated_filename
else:
print(f"No face landmarks detected in image: {filename}")
return None, None
except Exception as e:
print(f"Error processing image {filename}: {str(e)}")
import traceback
traceback.print_exc()
return None, None
def process_video(video_data, filename):
try:
temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
with open(temp_video_path, 'wb') as temp_video:
temp_video.write(video_data)
cap = cv2.VideoCapture(temp_video_path)
frame_count = 0
results = []
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
annotated_filename = f"mediapipe_{filename}"
output_path = os.path.join(RESULT_DIR, annotated_filename)
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % fps == 0:
frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
if frame_results:
results.append({"frame": frame_count, "results": frame_results})
else:
annotated_frame = frame
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
os.remove(temp_video_path)
return results, annotated_filename
except Exception as e:
print(f"Error processing video: {str(e)}")
return None, None
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(f"task:{task_id}", "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
json_results, annotated_filename = process_image(image_data, filename)
if json_results and annotated_filename is not None:
result_path = os.path.join(RESULT_DIR, annotated_filename)
redis_client.hset(f"mediapipe_result:{task_id}", mapping={
"result": json.dumps(json_results),
"result_file": annotated_filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "mediapipe")
main_redis_client.hset(f"task:{task_id}", "result_key", f"mediapipe_result:{task_id}")
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
json_results, annotated_filename = process_video(video_data, filename)
if json_results and annotated_filename:
redis_client.hset(f"mediapipe_result:{task_id}", mapping={
"result": json.dumps(json_results),
"result_file": annotated_filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "mediapipe")
main_redis_client.hset(f"task:{task_id}", "result_key", f"mediapipe_result:{task_id}")
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hset(f"task:{task_id}", {
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:mediapipe_result:*') # 监听所有fall_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"mediapipe_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("mediapipe处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+97
View File
@@ -0,0 +1,97 @@
# 将本地Ollama API完全反向代理
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import httpx
OLLAMA_URL = "http://127.0.0.1:11434"
app = FastAPI()
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
# 创建异步HTTP客户端
async_client = httpx.AsyncClient()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_to_ollama(request: Request, path: str):
if request.method == "OPTIONS":
# 处理预检请求
headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "*",
}
return StreamingResponse(content=iter([]), headers=headers)
target_url = f"{OLLAMA_URL}/{path}"
# 获取请求体
body = await request.body()
# 获取请求头
headers = dict(request.headers)
headers.pop("host", None)
try:
# 将请求转换为Python请求
python_request = {
"method": request.method,
"url": target_url,
"headers": headers,
"data": body
}
# 使用Python请求发送到Ollama API
async with async_client.stream(**python_request) as response:
# 返回Ollama API的流式响应,并添加CORS头
response_headers = dict(response.headers)
response_headers["Access-Control-Allow-Origin"] = "*"
return StreamingResponse(
response.aiter_raw(),
status_code=response.status_code,
headers=response_headers
)
except httpx.RequestError as e:
return {"error": f"请求Ollama API时发生错误: {str(e)}"}, 500
except httpx.StreamClosed:
# 处理流关闭异常
print("流已关闭,客户端可能已断开连接")
return {"error": "流已关闭,客户端可能已断开连接"}, 499
@app.on_event("shutdown")
async def shutdown_event():
await async_client.aclose()
if __name__ == "__main__":
import uvicorn
import requests
import json
# 测试Ollama API
test_url = "http://localhost:11434/api/generate"
test_data = {
"model": "llama3.1",
"prompt": "Why is the sky blue?",
"stream": False
}
try:
response = requests.post(test_url, json=test_data)
if response.status_code == 200:
print("Ollama API 测试成功:")
print(json.dumps(response.json(), indent=2))
else:
print(f"Ollama API 测试失败,状态码: {response.status_code}")
print(response.text)
except requests.RequestException as e:
print(f"Ollama API 测试出错: {str(e)}")
# 启动FastAPI应用
uvicorn.run(app, host="0.0.0.0", port=8001)
+292
View File
@@ -0,0 +1,292 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
import torch
from config import *
torch.cuda.set_device(CUDA_DEVICE_1)
# 配置
MODEL_PATH = POSE_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["pose"]["kafka_topic"]
KAFKA_GROUP_ID = f"pose_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["pose"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
class PoseDetector:
def __init__(self, model_path):
self.model = YOLO(model_path).to(CUDA_DEVICE_1)
def detect(self, frame):
results = self.model(frame, device=CUDA_DEVICE_1)
return results
def format_results(self, results, original_shape):
formatted_results = []
for r in results:
boxes = r.boxes
keypoints = r.keypoints
for i in range(len(boxes)):
box = boxes[i]
kpts = keypoints[i]
# 调整边界框坐标以适应原始图像大小
orig_h, orig_w = original_shape[:2]
model_h, model_w = r.orig_shape
scale_x, scale_y = orig_w / model_w, orig_h / model_h
bbox = box.xyxy[0].cpu().numpy()
bbox_scaled = [
bbox[0] * scale_x, bbox[1] * scale_y,
bbox[2] * scale_x, bbox[3] * scale_y
]
# 调整关键点坐标以适应原始图像大小
kpts_scaled = kpts.xy[0].cpu().numpy() * np.array([scale_x, scale_y])
formatted_results.append({
"bbox": bbox_scaled,
"confidence": box.conf.item(),
"keypoints": kpts_scaled.tolist()
})
return formatted_results
def draw_results(self, frame, results):
annotated_frame = frame.copy()
for r in results:
bbox = r["bbox"]
keypoints = r["keypoints"]
# 绘制边界框
cv2.rectangle(annotated_frame,
(int(bbox[0]), int(bbox[1])),
(int(bbox[2]), int(bbox[3])),
(0, 255, 0), 2)
# 绘制关键点
for kp in keypoints:
cv2.circle(annotated_frame,
(int(kp[0]), int(kp[1])),
5, (255, 0, 0), -1)
return annotated_frame
detector = PoseDetector(MODEL_PATH)
def process_image(image_path):
try:
original_img = cv2.imread(image_path)
original_shape = original_img.shape
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))
img = img.transpose((2, 0, 1))
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).float()
img /= 255.0
img = img.unsqueeze(0)
results = detector.detect(img)
json_results = detector.format_results(results, original_shape)
annotated_img = detector.draw_results(original_img, json_results)
return json_results, annotated_img
except Exception as e:
print(f"处理图像时出错: {str(e)}")
return None, None
def process_video(video_path):
try:
cap = cv2.VideoCapture(video_path)
frame_count = 0
json_results = []
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_shape = (height, width)
out = cv2.VideoWriter(video_path.replace(UPLOAD_DIR, RESULT_DIR), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % fps == 0:
preprocessed_frame = preprocess_frame(frame)
results = detector.detect(preprocessed_frame)
frame_json_results = detector.format_results(results, original_shape)
json_results.append({"frame": frame_count, "detections": frame_json_results})
annotated_frame = detector.draw_results(frame, frame_json_results)
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
return json_results
except Exception as e:
print(f"处理视频时出错: {str(e)}")
return None
def preprocess_frame(frame):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_resized = cv2.resize(frame_rgb, (640, 640))
frame_transposed = frame_resized.transpose((2, 0, 1))
frame_contiguous = np.ascontiguousarray(frame_transposed)
frame_tensor = torch.from_numpy(frame_contiguous).float()
frame_normalized = frame_tensor / 255.0
frame_batched = frame_normalized.unsqueeze(0)
return frame_batched
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
json_results, annotated_img = process_image(file_path)
if json_results and annotated_img is not None:
result_filename = f"pose_{filename}"
result_path = os.path.join(RESULT_DIR, result_filename)
cv2.imwrite(result_path, annotated_img)
redis_client.hmset(f"pose_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "pose",
"result_key": f"pose_result:{task_id}"
})
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
json_results = process_video(file_path)
if json_results:
result_filename = f"pose_{filename}"
redis_client.hmset(f"pose_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "pose",
"result_key": f"pose_result:{task_id}"
})
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hmset(f"task:{task_id}", {
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hmset':
value = redis_client.hgetall(f"pose_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("pose处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+419
View File
@@ -0,0 +1,419 @@
# main.py
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timedelta, timezone
import string
from decord import VideoReader
from PIL import Image
from fastapi.responses import FileResponse
import logging
from config import *
app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)
# CORS设置
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = KAFKA_BROKER
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = MAIN_REDIS_DB
REDIS_API_DB = REDIS_API_DB
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
MAX_FILE_AGE = timedelta(hours=1)
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 定义支持的任务类型
KAFKA_TOPICS = {
'pose': 'pose',
'mediapipe': 'mediapipe',
'qwenvl': 'qwenvl',
'yolo': 'yolo',
'fall': 'fall',
'face': 'face',
'cpm': 'cpm'
}
TASK_TYPES = list(KAFKA_TOPICS.keys())
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
redis_api_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_DB
)
redis_api_usage_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_USAGE_DB
)
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
if os.path.isfile(file_path):
return FileResponse(file_path)
else:
return {"message": "Favicon not found"}, 404
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# 定义 base62 字符集
BASE62 = string.digits + string.ascii_letters
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
logging.info(f"验证API密钥: {api_key}")
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
logging.warning(f"API密钥不存在: {api_key}")
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
logging.warning(f"API密钥已停用: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
logging.warning(f"API密钥已过期: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
logging.info(f"API密钥验证成功: {api_key}")
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
# 更新总的token使用量
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
# 更新特定模型的token使用量
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
def calculate_tokens(file_path: str, file_type: str) -> int:
base_tokens = 0
try:
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
# 基础token:每MB文件大小消耗10个token
base_tokens = int((file_size / (1024 * 1024)) * 10)
if file_type == "image":
img = Image.open(file_path)
width, height = img.size
pixel_count = width * height
# 图片token:每100个像素额外消耗5个token
image_tokens = int((pixel_count / 10000) * 5)
base_tokens += image_tokens
elif file_type == "video":
vr = VideoReader(file_path)
fps = vr.get_avg_fps()
frame_count = len(vr)
width, height = vr[0].shape[1], vr[0].shape[0]
pixel_count = width * height * frame_count
duration = frame_count / fps # 视频时长(秒)
# 视频token:每100万像素每秒额外消耗1个token
video_tokens = int((pixel_count / 10000) * (duration / 60))
base_tokens += video_tokens
return max(1, base_tokens) # 确保至少返回1个token
except Exception as e:
print(f"计算token时出错: {str(e)}")
return 1 # 出错时返回默认值1
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
if task_type not in KAFKA_TOPICS:
raise HTTPException(status_code=400, detail="不支持的任务类型")
content = await file.read()
file_extension = os.path.splitext(file.filename)[1].lower()
new_filename = f"{uuid.uuid4()}{file_extension}"
file_path = os.path.join(UPLOAD_DIR, new_filename)
with open(file_path, "wb") as f:
f.write(content)
# 计算 token
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
tokens_required = calculate_tokens(file_path, file_type)
# 检查并更新 token 使用量
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, task_type)
# 创建任务记录
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"filename": new_filename,
"file_type": file_type,
"task_type": task_type,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 存储任务信息到Redis
redis_client.set(f"task:{task_id}", json.dumps(task_data))
logging.info(f"任务信息已存储到Redis: {task_id}")
# 发送任务到对应的Kafka主题
kafka_topic = KAFKA_TOPICS[task_type]
producer.send(kafka_topic, task_data)
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
response_data = {
"message": "文件已上传并排队等待处理",
"task_id": task_id,
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{task_type}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
logging.info(f"上传文件完成: {task_id}")
return JSONResponse(content=response_data)
# 为每个任务类型创建单独的端点
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
logging.info(f"收到 /pose端点的请求")
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
logging.info(f"收到 /face 端点的请求")
return await upload_file(file, task_type="face", api_key_info=api_key_info)
@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务状态
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
return {"status": task_info['status'], "message": "Task is not completed yet"}
result_type = task_info['result_type']
result_key = task_info['result_key']
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'cpm': redis_cpm_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client,
'qwenvl': redis_qwenvl_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
# 将 result 字段解析为 JSON(如果存在)
if 'result' in result:
result['result'] = json.loads(result['result'])
return {
"status": "completed",
"result_type": result_type,
"result": result
}
@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务信息
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
raise HTTPException(status_code=400, detail="Task is not completed yet")
result_type = task_info.get('result_type')
result_key = task_info.get('result_key')
if not result_key:
raise HTTPException(status_code=404, detail="Result key not found")
if result_type in ['cpm', 'qwenvl']:
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
result_file = result.get('result_file')
if not result_file:
raise HTTPException(status_code=404, detail="Result file not found")
file_path = os.path.join(RESULT_DIR, result_file)
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="Result image file not found")
return FileResponse(file_path, media_type="image/png")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8005)
+271
View File
@@ -0,0 +1,271 @@
import os
import json
import uuid
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
from config import *
# 配置
MODEL_PATH = QWEN_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["qwenvl_analyze"]["kafka_topic"]
KAFKA_GROUP_ID = f"qwenvl_analyze_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["qwenvl_analyze"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 初始化模型
model = Qwen2VLForConditionalGeneration.from_pretrained(
MODEL_PATH, torch_dtype="auto", device_map="cuda:0"
)
min_pixels = 128*28*28
max_pixels = 512*28*28
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
class MediaAnalysisSystem:
def __init__(self, model, processor):
self.model = model
self.processor = processor
self.MAX_NUM_FRAMES = 10
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file)
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_media(self, media_data, object_name, media_type='image'):
if not media_data:
raise ValueError(f"Empty {media_type} data for {object_name}")
print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")
if media_type == 'video':
frames = self.encode_video(media_data)
media_content = {"type": "video", "video": frames, "fps": 1.0}
else: # image
image = Image.open(io.BytesIO(media_data))
media_content = {"type": "image", "image": image}
messages = [
{
"role": "user",
"content": [
media_content,
{"type": "text", "text": """您是一个高级OCR和文本分析助手。您的主要任务是:
1) 从这个""" + ("视频" if media_type == "video" else "图片") + """中高精度提取所有文本内容,包括标准文本和数字数据,
2) 对提取的内容进行全面分析,
3) 将信息进行逻辑性的组织和结构化,
4) 提供从文本/数字中发现的详细见解。
对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
请支持:
- 多种语言(简体中文、繁体中文、英文)
- 不同的文本方向和布局
- 表格、图表和结构化数据格式
- 特殊字符、符号和数学符号
- 原始格式和文本位置信息"""},
],
}
]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to('cuda:0')
generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
answer = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
result = {
"original_answer": answer,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
if media_type == 'video':
result["num_frames"] = len(frames)
return result
def process_video(self, video_data, object_name):
return self.process_media(video_data, object_name, media_type='video')
def process_image(self, image_data, object_name):
return self.process_media(image_data, object_name, media_type='image')
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem(model, processor)
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# Check key type and update status
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(f"task:{task_id}", "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # Skip this task and continue with the next one
try:
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"qwenvl_analyze_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "qwenvl_analyze")
main_redis_client.hset(f"task:{task_id}", "result_key", f"qwenvl_analyze_result:{task_id}")
print(f"{file_type.capitalize()} {filename} processed, result saved")
else:
print(f"{file_type.capitalize()} {filename} processing failed")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
error_msg = str(e)
print(f"处理任务 {task_id} 时出错: {error_msg}")
try:
# 分开设置每个字段,避免使用字典
main_redis_client.hset(f"task:{task_id}", "status", "failed")
main_redis_client.hset(f"task:{task_id}", "error", error_msg)
except Exception as redis_error:
print(f"更新任务状态时出错: {str(redis_error)}")
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:qwenvl_analyze_result:*') # Listen for changes on all qwenvl_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"qwenvl_analyze_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("qwenvl_analyze处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()
+329
View File
@@ -0,0 +1,329 @@
import os
import json
import uuid
from datetime import datetime, timedelta
from kafka import KafkaConsumer
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from decord import VideoReader, cpu
from PIL import Image
import redis
from redis import Redis
import io
import re
import threading
from config import *
# 配置
MODEL_PATH = QWEN_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["qwenvl"]["kafka_topic"]
KAFKA_GROUP_ID = f"qwenvl_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["qwenvl"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# 初始化模型
model = Qwen2VLForConditionalGeneration.from_pretrained(
MODEL_PATH, torch_dtype="auto", device_map="cuda:1"
)
min_pixels = 128*28*28
max_pixels = 512*28*28
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
class MediaAnalysisSystem:
def __init__(self, model, processor):
self.model = model
self.processor = processor
self.MAX_NUM_FRAMES = 10
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file)
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_media(self, media_data, object_name, media_type='image'):
if not media_data:
raise ValueError(f"Empty {media_type} data for {object_name}")
print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")
if media_type == 'video':
frames = self.encode_video(media_data)
media_content = {"type": "video", "video": frames, "fps": 1.0}
else: # image
image = Image.open(io.BytesIO(media_data))
media_content = {"type": "image", "image": image}
messages = [
{
"role": "user",
"content": [
media_content,
{"type": "text", "text": f"""请对这{'段监控视频' if media_type == 'video' else '张监控图像'}进行详细分析,包括以下方面:
1. 场景中人数的精确统计
2. 每个人的个人行为分析
3. 面部表情识别和情绪状态评估
4. 整体场景和环境的详细描述
5. 人与人之间的互动情况
6. 时间和环境条件(如果可见)
7. 任何可疑或异常活动
8. 人员的具体特征(估计年龄范围、性别、着装)
9. 人员的{'移动模式和方向' if media_type == 'video' else '位置和姿态'}
10. 携带的物品或物体
11. 群体动态和聚集情况
12. {'视频' if media_type == 'video' else '图像'}中的时间戳信息(如果有)
请用清晰、有条理的格式描述,并突出重要发现。"""},
],
}
]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to('cuda:1')
generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
answer = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
extracted_info = self.extract_info(answer)
result = {
"original_answer": answer,
"extracted_info": extracted_info,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
if media_type == 'video':
result["num_frames"] = len(frames)
return result
def process_video(self, video_data, object_name):
return self.process_media(video_data, object_name, media_type='video')
def process_image(self, image_data, object_name):
return self.process_media(image_data, object_name, media_type='image')
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "室外", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]:
for item in item_list:
if item in answer:
info[key].append(item)
return info
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem(model, processor)
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# Check key type and update status
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(f"task:{task_id}", "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # Skip this task and continue with the next one
try:
if file_type == "image":
print(f"Processing image: {filename}")
with open(file_path, 'rb') as f:
image_data = f.read()
result = media_analysis_system.process_image(image_data, filename)
else: # video
print(f"Processing video: {filename}")
with open(file_path, 'rb') as f:
video_data = f.read()
result = media_analysis_system.process_video(video_data, filename)
if result:
redis_client.hset(f"qwenvl_result:{task_id}", mapping={
"result": json.dumps(result),
"result_file": filename
})
main_redis_client.hset(f"task:{task_id}", "status", "completed")
main_redis_client.hset(f"task:{task_id}", "result_type", "qwenvl")
main_redis_client.hset(f"task:{task_id}", "result_key", f"qwenvl_result:{task_id}")
print(f"{file_type.capitalize()} {filename} processed, result saved")
else:
print(f"{file_type.capitalize()} {filename} processing failed")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
error_msg = str(e)
print(f"处理任务 {task_id} 时出错: {error_msg}")
try:
# 分开设置每个字段,避免使用字典
main_redis_client.hset(f"task:{task_id}", "status", "failed")
main_redis_client.hset(f"task:{task_id}", "error", error_msg)
except Exception as redis_error:
print(f"更新任务状态时出错: {str(redis_error)}")
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:qwenvl_result:*') # Listen for changes on all qwenvl_result keys
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hset':
value = redis_client.hgetall(f"qwenvl_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("qwenvl处理程序启动...")
# Start the task processing thread
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# Start the Redis listening thread
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# Keep the main thread running
task_thread.join()
redis_thread.join()
+282
View File
@@ -0,0 +1,282 @@
import os
import cv2
import torch
import numpy as np
from redis import Redis
from ultralytics import YOLO
import json
from kafka import KafkaConsumer
import threading
import redis
from config import *
# 配置
MODEL_PATH = YOLO_MODEL_PATH
KAFKA_BROKER = KAFKA_BROKER
KAFKA_TOPIC = WORKER_CONFIGS["yolo"]["kafka_topic"]
KAFKA_GROUP_ID = f"yolo_{KAFKA_GROUP_ID_PREFIX}"
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = WORKER_CONFIGS["yolo"]["redis_db"] # Worker使用的Redis DB
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
KAFKA_TOPIC = 'yolo'
# 初始化 Kafka Consumer
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
main_redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=MAIN_REDIS_DB
)
# YOLO detector class
class yoloDetector:
def __init__(self, model_path):
self.model = YOLO(model_path)
def detect(self, frame):
results = self.model(frame)
return results
def format_results(self, results, original_shape):
formatted_results = []
for r in results:
boxes = r.boxes
for box in boxes:
x1, y1, x2, y2 = box.xyxy[0].tolist()
x1, x2 = [x * original_shape[1] / 640 for x in [x1, x2]]
y1, y2 = [y * original_shape[0] / 640 for y in [y1, y2]]
conf = box.conf.item()
cls = int(box.cls.item())
name = self.model.names[cls]
formatted_results.append({
"class": name,
"confidence": conf,
"bbox": [x1, y1, x2, y2]
})
return formatted_results
def draw_results(self, frame, formatted_results):
for result in formatted_results:
x1, y1, x2, y2 = map(int, result['bbox'])
name = result['class']
conf = result['confidence']
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{name} {conf:.2f}"
(text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
cv2.rectangle(frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), (0, 255, 0), -1)
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
return frame
detector = yoloDetector(MODEL_PATH)
def process_image(image_path):
try:
original_img = cv2.imread(image_path)
original_shape = original_img.shape
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))
img = img.transpose((2, 0, 1))
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).float()
img /= 255.0
img = img.unsqueeze(0)
results = detector.detect(img)
json_results = detector.format_results(results, original_shape)
annotated_img = detector.draw_results(original_img, json_results)
return json_results, annotated_img
except Exception as e:
print(f"处理图像时出错: {str(e)}")
return None, None
def process_video(video_path):
try:
cap = cv2.VideoCapture(video_path)
frame_count = 0
json_results = []
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
original_shape = (height, width)
out = cv2.VideoWriter(video_path.replace(UPLOAD_DIR, RESULT_DIR), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % fps == 0:
preprocessed_frame = preprocess_frame(frame)
results = detector.detect(preprocessed_frame)
frame_json_results = detector.format_results(results, original_shape)
json_results.append({"frame": frame_count, "detections": frame_json_results})
annotated_frame = detector.draw_results(frame, frame_json_results)
out.write(annotated_frame)
frame_count += 1
cap.release()
out.release()
return json_results
except Exception as e:
print(f"处理视频时出错: {str(e)}")
return None
def preprocess_frame(frame):
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_resized = cv2.resize(frame_rgb, (640, 640))
frame_transposed = frame_resized.transpose((2, 0, 1))
frame_contiguous = np.ascontiguousarray(frame_transposed)
frame_tensor = torch.from_numpy(frame_contiguous).float()
frame_normalized = frame_tensor / 255.0
frame_batched = frame_normalized.unsqueeze(0)
return frame_batched
def process_task():
print("开始处理任务,等待Kafka消息...")
for message in consumer:
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
task = message.value
task_id = task['task_id']
filename = task['filename']
file_type = task['file_type']
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
file_path = os.path.join(UPLOAD_DIR, filename)
# 检查键类型并更新状态
task_key = f"task:{task_id}"
try:
key_type = main_redis_client.type(task_key)
if key_type != b'hash':
main_redis_client.delete(task_key)
main_redis_client.hset(task_key, "status", "processing")
print(f"任务 {task_id} 状态更新为 'processing'")
except redis.exceptions.ResponseError as e:
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
continue # 跳过这个任务,继续处理下一个
try:
if file_type == "image":
print(f"开始处理图像: {filename}")
json_results, annotated_img = process_image(file_path)
if json_results and annotated_img is not None:
result_filename = f"yolo_{filename}"
result_path = os.path.join(RESULT_DIR, result_filename)
cv2.imwrite(result_path, annotated_img)
redis_client.hmset(f"yolo_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "yolo",
"result_key": f"yolo_result:{task_id}"
})
print(f"图像 {filename} 处理完成,结果已保存")
else:
print(f"图像 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
else: # video
print(f"开始处理视频: {filename}")
json_results = process_video(file_path)
if json_results:
result_filename = f"yolo_{filename}"
redis_client.hmset(f"yolo_result:{task_id}", {
"result": json.dumps(json_results),
"result_file": result_filename
})
main_redis_client.hmset(f"task:{task_id}", {
"status": "completed",
"result_type": "yolo",
"result_key": f"yolo_result:{task_id}"
})
print(f"视频 {filename} 处理完成,结果已保存")
else:
print(f"视频 {filename} 处理失败")
main_redis_client.hset(f"task:{task_id}", "status", "failed")
except Exception as e:
print(f"处理任务 {task_id} 时出错: {str(e)}")
main_redis_client.hmset(f"task:{task_id}", {
"status": "failed",
"error": str(e)
})
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@3__:yolo_result:*') # 监听所有yolo_result键的变化
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
operation = message['data'].decode('utf-8')
if operation == 'hmset':
value = redis_client.hgetall(f"yolo_result:{key}")
if value:
result = {k.decode(): v.decode() for k, v in value.items()}
print(f"Result update for task {key}: {result}")
if __name__ == "__main__":
print("YOLO处理程序启动...")
# 启动处理任务的线程
task_thread = threading.Thread(target=process_task, daemon=True)
task_thread.start()
print("任务处理线程已启动")
# 启动Redis监听线程
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
redis_thread.start()
print("Redis监听线程已启动")
print("主程序进入等待状态...")
# 保持主线程运行
task_thread.join()
redis_thread.join()
+94
View File
@@ -0,0 +1,94 @@
# Kafka 配置
KAFKA_BROKER=222.186.10.253:9092
KAFKA_ASR_TOPIC=asr
KAFKA_CHAT_TOPIC=chat
KAFKA_TTS_TOPIC=tts
# Redis 配置
REDIS_HOST=150.158.144.159
REDIS_PORT=13003
REDIS_ASR_DB=12
REDIS_CHAT_DB=13
REDIS_TTS_DB=14
REDIS_PASSWORD=Obscura@2024
REDIS_API_DB=2
REDIS_API_USAGE_DB=3
REDIS_TASK_DB=11
REDIS_SESSION_DB=63
REDIS_SESSION_DB_ZH=63
REDIS_SESSION_DB_EN=62
REDIS_SESSION_DB_KO=61
# CORS 配置
# ALLOWED_ORIGINS=https://beta.obscura.work
# GPT-SoVITS 配置
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
REF_AUDIO_PATH=sample/woman.wav
REF_TEXT_PATH=sample/woman.txt
REF_LANGUAGE=中文
TARGET_LANGUAGE=多语种混合
OUTPUT_PATH=/obscura/task/audio_files
# VOICE_CONFIGS
GIRL_REF_AUDIO=sample/gril.wav
GIRL_REF_TEXT=sample/gril.txt
WOMAN_REF_AUDIO=sample/woman.wav
WOMAN_REF_TEXT=sample/woman.txt
MAN_REF_AUDIO=sample/man.wav
MAN_REF_TEXT=sample/man.txt
LEIJUN_REF_AUDIO=sample/leijun.wav
LEIJUN_REF_TEXT=sample/leijun.txt
DUFU_REF_AUDIO=sample/dufu.wav
DUFU_REF_TEXT=sample/dufu.txt
HEJIONG_REF_AUDIO=sample/hejiong.wav
HEJIONG_REF_TEXT=sample/hejiong.txt
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
MAHUATENG_REF_TEXT=sample/mahuateng.txt
LIDAN_REF_AUDIO=sample/lidan.wav
LIDAN_REF_TEXT=sample/lidan.txt
YUHUA_REF_AUDIO=sample/yuhua.wav
YUHUA_REF_TEXT=sample/yuhua.txt
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
DABING_REF_AUDIO=sample/dabing.wav
DABING_REF_TEXT=sample/dabing.txt
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
LUOXIANG_REF_TEXT=sample/luoxiang.txt
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
REDIS_GIRL_DB = 15
REDIS_WOMAN_DB = 16
REDIS_MAN_DB = 17
REDIS_LEIJUN_DB = 18
REDIS_DUFU_DB = 19
REDIS_HEJIONG_DB = 20
REDIS_MAHUATENG_DB = 21
REDIS_LIDAN_DB = 22
REDIS_DABING_DB = 23
REDIS_LUOXIANG_DB = 24
REDIS_XUZHIYUAN_DB = 25
REDIS_YUHUA_DB = 26
REDIS_LIUZHENYUN_DB = 27
View File
View File
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+110
View File
@@ -0,0 +1,110 @@
import whisper
import os
import json
import redis
from dotenv import load_dotenv
from kafka import KafkaConsumer
import asyncio
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
# 加载环境变量
load_dotenv()
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
# Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
# Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_asr_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_ASR_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
async def process_audio(file_path: str, cache_key: str):
try:
# 更新任务状态
redis_task_client.set(f"task_status:{cache_key}", "processing")
result = model.transcribe(file_path)
transcription = result['text']
print(f"处理了文件: {file_path}")
print(f"转录结果: {transcription}")
redis_asr_client.setex(cache_key, 3600, transcription)
result_data = {
'transcription': transcription
}
redis_asr_client.publish('asr_results', json.dumps(result_data))
# 更新任务状态
redis_task_client.set(f"task_status:{cache_key}", "completed")
os.remove(file_path)
except Exception as e:
print(f"处理音频文件时发生错误: {str(e)}")
# 更新任务状态
redis_task_client.set(f"task_status:{cache_key}", "error")
async def kafka_consumer():
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
group_id='asr_group',
auto_offset_reset='earliest',
enable_auto_commit=True
)
print(f"ASR消费者已启动")
for message in consumer:
try:
task = message.value
file_path = task.get('file_path')
task_id = task.get('task_id')
status = task.get('status')
if not file_path or not task_id or status != 'queued':
print(f"收到无效任务: {task}")
continue
cache_key = f"asr:{task_id}"
print(f"开始处理任务: {cache_key}")
await process_audio(file_path, cache_key)
print(f"完成处理任务: {cache_key}")
except Exception as e:
print(f"处理消息时发生错误: {str(e)}")
if __name__ == "__main__":
print("启动Kafka消费者处理ASR请求...")
asyncio.run(kafka_consumer())
+115
View File
@@ -0,0 +1,115 @@
import whisper
import os
import json
import redis
from dotenv import load_dotenv
from kafka import KafkaConsumer
import threading
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
# 加载环境变量
load_dotenv()
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
# Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
# Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_asr_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_ASR_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
def process_audio(file_path: str, client_id: str, cache_key: str):
try:
# 设置任务状态为 "processing"
redis_task_client.set(f"task_status:{cache_key}", "processing")
result = model.transcribe(file_path)
transcription = result['text']
print(f"处理了文件: {file_path}")
print(f"转录结果: {transcription}")
# 将结果存入Redis缓存
redis_asr_client.setex(cache_key, 3600, transcription) # 缓存1小时
# 发布结果到Redis频道
result_data = {
'client_id': client_id,
'transcription': transcription
}
redis_asr_client.publish('asr_results', json.dumps(result_data))
# 设置任务状态为 "completed"
redis_task_client.set(f"task_status:{cache_key}", "completed")
# 清理临时文件
os.remove(file_path)
except Exception as e:
print(f"处理音频文件时发生错误: {str(e)}")
# 设置任务状态为 "error"
redis_task_client.set(f"task_status:{cache_key}", "error")
def kafka_consumer():
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
group_id='asr_group',
auto_offset_reset='earliest',
enable_auto_commit=True
)
print(f"ASR消费者已启动")
for message in consumer:
try:
task = message.value
file_path = task.get('file_path')
task_id = task.get('task_id')
status = task.get('status')
if not file_path or not task_id or status != 'queued':
print(f"收到无效任务: {task}")
continue
cache_key = f"asr:{task_id}"
client_id = task_id # 使用task_id作为client_id
print(f"开始处理任务: {cache_key}")
process_audio(file_path, client_id, cache_key)
print(f"完成处理任务: {cache_key}")
except Exception as e:
print(f"处理消息时发生错误: {str(e)}")
if __name__ == "__main__":
print("启动Kafka消费者处理ASR请求...")
kafka_consumer()
+117
View File
@@ -0,0 +1,117 @@
from kafka import KafkaConsumer
import json
import asyncio
import redis
import os
from dotenv import load_dotenv
import requests
from concurrent.futures import ThreadPoolExecutor
# 加载 .env 文件
load_dotenv()
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 3 # 消费者数量
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_CHAT_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
# 创建Kafka消费者
def create_kafka_consumer():
return KafkaConsumer(
KAFKA_CHAT_TOPIC,
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='latest',
enable_auto_commit=True,
group_id=KAFKA_CONSUMER_GROUP,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
async def process_chat_request(chat_request):
try:
session_id = chat_request['session_id']
query = chat_request['query']
model = chat_request.get('model', 'qwen2.5:3b')
# 设置任务状态为 "processing"
redis_task_client.set(f"task_status:{session_id}", "processing")
# 从Redis获取历史记录
history = json.loads(redis_client.get(session_id) or '[]')
# 构建包含历史对话的完整提示
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
for past_query, past_response in history:
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
full_prompt += f"用户: {query}\n助手:"
data = {
"model": model,
"prompt": full_prompt,
"stream": True,
"temperature": 0
}
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
response.raise_for_status()
text_output = ""
for line in response.iter_lines():
if line:
json_data = json.loads(line)
if 'response' in json_data:
text_output += json_data['response']
# 更新历史记录
history.append((query, text_output))
redis_client.set(session_id, json.dumps(history))
# 设置任务状态为 "completed"
redis_task_client.set(f"task_status:{session_id}", "completed")
print(f"处理完成 session {session_id}: {text_output}")
except Exception as e:
print(f"处理 session {chat_request['session_id']} 时出错: {str(e)}")
# 设置任务状态为 "error"
redis_task_client.set(f"task_status:{chat_request['session_id']}", "error")
def kafka_consumer_thread(consumer_id):
consumer = create_kafka_consumer()
print(f"消费者 {consumer_id} 已启动")
for message in consumer:
chat_request = message.value
asyncio.run(process_chat_request(chat_request))
def main():
print("启动Kafka消费者处理聊天请求...")
with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
for i in range(KAFKA_CONSUMER_NUM):
executor.submit(kafka_consumer_thread, i)
if __name__ == '__main__':
main()
+182
View File
@@ -0,0 +1,182 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
from typing import List, Tuple
from kafka import KafkaConsumer, TopicPartition
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(f"Using device: {device}")
# Load MiniCPM3-4B model
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
# 加载 .env 文件
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
KAFKA_CONSUMER_GROUP = 'mini3_group'
KAFKA_CONSUMER_NUM = 1
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 使用密码进行认证
)
# 创建Kafka消费者
def create_kafka_consumer():
return KafkaConsumer(
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='earliest',
enable_auto_commit=True,
group_id=KAFKA_CONSUMER_GROUP,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# Kafka消费者函数
def kafka_consumer(consumer, consumer_id):
# 获取消费者分配的分区
consumer.subscribe([KAFKA_TOPIC])
partitions = consumer.assignment()
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
for message in consumer:
partition = message.partition
offset = message.offset
chat_request = message.value # 直接使用 message.value,它已经是一个字典
session_id = chat_request['session_id']
query = chat_request['query']
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
asyncio.run(process_chat_request(chat_request))
# 启动Kafka消费者线程
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
consumers = []
for i in range(num_consumers):
consumer = create_kafka_consumer()
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
consumer_thread.start()
consumers.append((consumer, consumer_thread))
return consumers
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
class ChatRequest(BaseModel):
session_id: str
query: str
model: str = "minicpm3-4b"
class ChatResponse(BaseModel):
response: str
history: List[Tuple[str, str]]
# 处理聊天请求的异步函数
async def process_chat_request(chat_request):
try:
response = await chat(ChatRequest(**chat_request))
print(f"Processed message for session {chat_request['session_id']}: {response}")
except Exception as e:
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
@app.post("/mini3", response_model=ChatResponse)
async def chat(request: ChatRequest):
session_id = request.session_id
query = request.query
# 从Redis获取历史记录
history = json.loads(redis_client.get(session_id) or '[]')
# 构建包含历史对话的完整提示
messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
for past_query, past_response in history:
messages.append({"role": "user", "content": past_query})
messages.append({"role": "assistant", "content": past_response})
messages.append({"role": "user", "content": query})
try:
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
# 创建注意力掩码
attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()
# 将输入移动到正确的设备(CPU或GPU)
model_inputs = model_inputs.to(device)
attention_mask = attention_mask.to(device)
model_outputs = model.generate(
model_inputs,
attention_mask=attention_mask,
max_new_tokens=1024,
top_p=0.7,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id, # 将pad_token_id设置为eos_token_id
do_sample=True
)
output_token_ids = model_outputs[0][len(model_inputs[0]):]
text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)
# 更新历史记录
history.append((query, text_output))
redis_client.set(session_id, json.dumps(history))
return ChatResponse(response=text_output, history=history)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/start_chat")
async def start_chat():
session_id = str(uuid.uuid4())
return {"session_id": session_id}
if __name__ == '__main__':
# 启动Kafka消费者线程
start_kafka_consumers()
# 启动FastAPI服务器
uvicorn.run(app, host="0.0.0.0", port=6003)
+170
View File
@@ -0,0 +1,170 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
from typing import List, Tuple
from kafka import KafkaConsumer, TopicPartition
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(f"Using device: {device}")
# 加载 .env 文件
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 1
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 使用密码进行认证
)
# 创建Kafka消费者
def create_kafka_consumer():
return KafkaConsumer(
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='earliest',
enable_auto_commit=True,
group_id=KAFKA_CONSUMER_GROUP,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# Kafka消费者函数
def kafka_consumer(consumer, consumer_id):
# 获取消费者分配的分区
consumer.subscribe([KAFKA_TOPIC])
partitions = consumer.assignment()
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
for message in consumer:
partition = message.partition
offset = message.offset
chat_request = message.value # 直接使用 message.value,它已经是一个字典
session_id = chat_request['session_id']
query = chat_request['query']
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
asyncio.run(process_chat_request(chat_request))
# 启动Kafka消费者线程
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
consumers = []
for i in range(num_consumers):
consumer = create_kafka_consumer()
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
consumer_thread.start()
consumers.append((consumer, consumer_thread))
return consumers
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
class ChatRequest(BaseModel):
session_id: str
query: str
model: str = "qwen2.5:3b"
class ChatResponse(BaseModel):
response: str
history: List[Tuple[str, str]]
# 处理聊天请求的异步函数
async def process_chat_request(chat_request):
try:
response = await chat(ChatRequest(**chat_request))
print(f"Processed message for session {chat_request['session_id']}: {response}")
except Exception as e:
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
session_id = request.session_id
query = request.query
model = request.model
# 从Redis获取历史记录
history = json.loads(redis_client.get(session_id) or '[]')
# 构建包含历史对话的完整提示
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
for past_query, past_response in history:
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
full_prompt += f"用户: {query}"
data = {
"model": model,
"prompt": full_prompt,
"stream": True,
"temperature": 0
}
try:
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
response.raise_for_status()
text_output = ""
for line in response.iter_lines():
if line:
json_data = json.loads(line)
if 'response' in json_data:
text_output += json_data['response']
# 更新历史记录
history.append((query, text_output))
redis_client.set(session_id, json.dumps(history))
return ChatResponse(response=text_output, history=history)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/start_chat")
async def start_chat():
session_id = str(uuid.uuid4())
return {"session_id": session_id}
if __name__ == '__main__':
# 启动Kafka消费者线程
start_kafka_consumers()
# 启动FastAPI服务器
uvicorn.run(app, host="0.0.0.0", port=6001)
+319
View File
@@ -0,0 +1,319 @@
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import tempfile
import hashlib
import asyncio
# 加载 .env 文件
load_dotenv()
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, model_name)
# 发送任务到Kafka
producer.send(kafka_topic, task_data)
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
return {
"message": f"{model_name.upper()}请求已排队等待处理",
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{model_name}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
class TTSRequest(BaseModel):
text: str
class ChatRequest(BaseModel):
session_id: str
query: str
model: str = "qwen2.5:3b"
# 添加WebSocket连接管理
class ConnectionManager:
def __init__(self):
self.active_connections = {}
async def connect(self, websocket: WebSocket, client_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):
self.active_connections.pop(client_id, None)
async def send_message(self, message: str, client_id: str):
if client_id in self.active_connections:
await self.active_connections[client_id].send_text(message)
manager = ConnectionManager()
@v1_chat_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
await manager.connect(websocket, client_id)
try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
manager.disconnect(client_id)
# 修改TTS请求处理函数
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"text": request.text,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
redis_task_client.set(f"task_status:{task_id}", "queued")
result = await process_request(api_key_info, "tts", 100, task_data, KAFKA_TTS_TOPIC)
result["task_id"] = task_id
# 将任务ID存储到Redis,以便后续WebSocket通信使用
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
return result
# 修改ASR请求处理函数
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
UPLOAD_DIR = "/obscura/task/audio_upload"
os.makedirs(UPLOAD_DIR, exist_ok=True)
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
with open(file_path, "wb") as temp_audio:
content = await audio.read()
temp_audio.write(content)
task_data = {
'file_path': file_path,
'task_id': task_id,
'status': 'queued'
}
redis_task_client.set(f"task_status:{task_id}", "queued")
result = await process_request(api_key_info, "asr", 100, task_data, KAFKA_ASR_TOPIC)
result["task_id"] = task_id
# 将任务ID存储到Redis,以便后续WebSocket通信使用
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
return result
# 修改聊天请求处理函数
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"session_id": request.session_id,
"query": request.query,
"model": request.model,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
redis_task_client.set(f"task_status:{task_id}", "queued")
result = await process_request(api_key_info, "chat", 100, task_data, KAFKA_CHAT_TOPIC)
result["task_id"] = task_id
# 将任务ID存储到Redis,以便后续WebSocket通信使用
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态
task_status = redis_task_client.get(f"task_status:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis聊天结果数据库获取聊天结果
chat_result = redis_chat_client.get(task_id)
if chat_result:
result = json.loads(chat_result)
return {
"status": "completed",
"history": result # 直接返回整个历史记录
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态
task_status = redis_task_client.get(f"task_status:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis TTS结果数据库获取音频文件路径
audio_info = redis_tts_client.get(task_id)
if audio_info:
audio_path = json.loads(audio_info)['path']
return {
"status": "completed",
"audio_path": audio_path
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态
task_status = redis_task_client.get(f"task_status:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis ASR结果数据库获取转录结果
transcription = redis_asr_client.get(task_id)
if transcription:
return {
"status": "completed",
"transcription": transcription.decode('utf-8')
}
return {"status": status}
return {"status": "not_found"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8008)
+180
View File
@@ -0,0 +1,180 @@
import os
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import uvicorn
import redis
import hashlib
import json
from kafka import KafkaProducer, KafkaConsumer
import threading
import time
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import os
import torch
# 加载 .env 文件
load_dotenv()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# FastAPI configuration
app = FastAPI()
i18n = I18nAuto()
# CORS configuration
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# Redis configuration
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# KAFKA_GROUP_ID = 'tts_group'
KAFKA_CONSUMER_THREADS = 1
# TTS configuration
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_PATH')
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# Initialize FastAPI CORS
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize Redis client
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD
)
# Initialize Kafka producer
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
class TTSRequest(BaseModel):
text: str = Field(..., alias="text")
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# Initialize models at startup
print("Initializing models...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# Read reference text
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("Models initialized successfully.")
def synthesize(target_text, output_path):
# Synthesize audio
with torch.cuda.device(device):
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
prompt_text=ref_text,
prompt_language=i18n(REF_LANGUAGE),
text=target_text,
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
audio_hash = get_audio_hash(target_text)
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
return output_wav_path
else:
return None
@app.post("/tts")
async def synthesize_audio(request: TTSRequest):
try:
print(f"Received TTS request: {request.dict()}")
target_text = request.text
audio_hash = get_audio_hash(target_text)
# Check Redis cache
cached_audio = redis_client.get(audio_hash)
if cached_audio:
audio_info = json.loads(cached_audio)
return FileResponse(audio_info['path'], media_type="audio/wav")
# Check file system
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
if os.path.exists(file_path):
# Cache the file path in Redis
redis_client.set(audio_hash, json.dumps({"path": file_path}))
return FileResponse(file_path, media_type="audio/wav")
# Send message to Kafka
kafka_producer.send(KAFKA_TOPIC, json.dumps({
'text': target_text,
'audio_hash': audio_hash
}).encode('utf-8'))
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
for _ in range(60): # Wait for up to 30 seconds
if os.path.exists(file_path):
return FileResponse(file_path, media_type="audio/wav")
time.sleep(1)
# If audio is not generated within the timeout
raise HTTPException(status_code=504, detail="Audio generation timed out")
except Exception as e:
print(f"Error processing TTS request: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
return {"message": "TTS API is running"}
def kafka_consumer_thread():
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=KAFKA_BROKER,
# group_id=KAFKA_GROUP_ID,
auto_offset_reset='latest',
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
)
for message in consumer:
target_text = message.value['text']
audio_hash = message.value['audio_hash']
output_path = synthesize(target_text, OUTPUT_PATH)
if output_path:
redis_client.set(audio_hash, json.dumps({"path": output_path}))
print(f"Audio synthesized successfully: {output_path}")
else:
print("Failed to synthesize audio")
if __name__ == "__main__":
# Start Kafka consumer threads
torch.cuda.set_device(device)
for _ in range(KAFKA_CONSUMER_THREADS):
consumer_thread = threading.Thread(target=kafka_consumer_thread)
consumer_thread.start()
uvicorn.run(app, host="0.0.0.0", port=6002)
+180
View File
@@ -0,0 +1,180 @@
import os
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import uvicorn
import redis
import hashlib
import json
from kafka import KafkaProducer, KafkaConsumer
import threading
import time
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import os
import torch
# 加载 .env 文件
load_dotenv()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# FastAPI configuration
app = FastAPI()
i18n = I18nAuto()
# CORS configuration
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# Redis configuration
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# KAFKA_GROUP_ID = 'tts_group'
KAFKA_CONSUMER_THREADS = 1
# TTS configuration
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH')
REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# Initialize FastAPI CORS
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize Redis client
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD
)
# Initialize Kafka producer
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
class TTSRequest(BaseModel):
text: str = Field(..., alias="text")
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# Initialize models at startup
print("Initializing models...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# Read reference text
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("Models initialized successfully.")
def synthesize(target_text, output_path):
# Synthesize audio
with torch.cuda.device(device):
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
prompt_text=ref_text,
prompt_language=i18n(REF_LANGUAGE),
text=target_text,
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
audio_hash = get_audio_hash(target_text)
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
return output_wav_path
else:
return None
@app.post("/tts_ko")
async def synthesize_audio(request: TTSRequest):
try:
print(f"Received TTS request: {request.dict()}")
target_text = request.text
audio_hash = get_audio_hash(target_text)
# Check Redis cache
cached_audio = redis_client.get(audio_hash)
if cached_audio:
audio_info = json.loads(cached_audio)
return FileResponse(audio_info['path'], media_type="audio/wav")
# Check file system
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
if os.path.exists(file_path):
# Cache the file path in Redis
redis_client.set(audio_hash, json.dumps({"path": file_path}))
return FileResponse(file_path, media_type="audio/wav")
# Send message to Kafka
kafka_producer.send(KAFKA_TOPIC, json.dumps({
'text': target_text,
'audio_hash': audio_hash
}).encode('utf-8'))
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
for _ in range(60): # Wait for up to 30 seconds
if os.path.exists(file_path):
return FileResponse(file_path, media_type="audio/wav")
time.sleep(1)
# If audio is not generated within the timeout
raise HTTPException(status_code=504, detail="Audio generation timed out")
except Exception as e:
print(f"Error processing TTS request: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
return {"message": "TTS API is running"}
def kafka_consumer_thread():
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=KAFKA_BROKER,
# group_id=KAFKA_GROUP_ID,
auto_offset_reset='latest',
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
)
for message in consumer:
target_text = message.value['text']
audio_hash = message.value['audio_hash']
output_path = synthesize(target_text, OUTPUT_PATH)
if output_path:
redis_client.set(audio_hash, json.dumps({"path": output_path}))
print(f"Audio synthesized successfully: {output_path}")
else:
print("Failed to synthesize audio")
if __name__ == "__main__":
# Start Kafka consumer threads
torch.cuda.set_device(device)
for _ in range(KAFKA_CONSUMER_THREADS):
consumer_thread = threading.Thread(target=kafka_consumer_thread)
consumer_thread.start()
uvicorn.run(app, host="0.0.0.0", port=6003)
+176
View File
@@ -0,0 +1,176 @@
# 导入所需的库
import os
import soundfile as sf
import redis
import hashlib
import json
from kafka import KafkaConsumer
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import torch
"""
整体设计说明:
这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务,
使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。
主要功能包括:
1. 初始化配置和模型
2. 提供语音合成功能
3. 监听Kafka消息并处理TTS任务
4. 将合成结果存储到Redis并更新任务状态
"""
# 加载环境变量
load_dotenv()
# 设置GPU设备(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 从环境变量中读取Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 从环境变量中读取Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# 从环境变量中读取TTS相关配置
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH')
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# 初始化Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TTS_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
# 初始化国际化工具
i18n = I18nAuto()
def get_audio_hash(text):
"""
生成文本的MD5哈希值,用作音频文件名的一部分
参数:
text (str): 需要生成哈希的文本
返回:
str: 文本的MD5哈希值
"""
return hashlib.md5(text.encode()).hexdigest()
# 初始化模型
print("正在初始化模型...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# 读取参考文本
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("模型初始化成功。")
def synthesize(target_text, output_wav_path):
"""
使用GPT-SoVITS模型合成语音
参数:
target_text (str): 需要合成语音的目标文本
output_wav_path (str): 输出音频文件的路径
返回:
str: 如果成功,返回输出音频文件的路径;如果失败,返回None
"""
with torch.cuda.device(device):
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
prompt_text=ref_text,
prompt_language=i18n(REF_LANGUAGE),
text=target_text,
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
return output_wav_path
else:
return None
def kafka_consumer():
"""
Kafka消费者函数,用于接收和处理TTS任务
该函数会持续监听Kafka的TTS主题,接收任务并进行处理:
1. 接收任务信息
2. 更新任务状态
3. 调用synthesize函数合成语音
4. 将结果保存到Redis
5. 更新任务完成状态
"""
consumer = KafkaConsumer(
KAFKA_TTS_TOPIC,
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='latest',
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
)
print(f"TTS消费者已启动")
for message in consumer:
try:
task_id = message.value['task_id']
target_text = message.value['text']
text_hash = message.value['text_hash']
# 更新任务状态为 "processing"
redis_task_client.set(f"task_status:tts:{task_id}", "processing")
output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav")
# 再次检查文件是否存在(以防在此期间被其他进程创建)
if not os.path.exists(output_wav_path):
output_path = synthesize(target_text, output_wav_path)
else:
output_path = output_wav_path
if output_path:
# 将结果保存在 DB 2
redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path}))
print(f"音频合成成功: {output_path}")
# 更新任务状态为 "completed"
redis_task_client.set(f"task_status:tts:{task_id}", "completed")
else:
print("音频合成失败")
# 更新任务状态为 "failed"
redis_task_client.set(f"task_status:tts:{task_id}", "failed")
except Exception as e:
print(f"处理消息时出错: {str(e)}")
# 更新任务状态为 "failed"
redis_task_client.set(f"task_status:tts:{task_id}", "failed")
if __name__ == "__main__":
# 设置CUDA设备
torch.cuda.set_device(device)
# 启动Kafka消费者
kafka_consumer()
+186
View File
@@ -0,0 +1,186 @@
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
from fastapi.middleware.cors import CORSMiddleware
import whisper
import tempfile
import uvicorn
from kafka import KafkaProducer, KafkaConsumer
from starlette.websockets import WebSocketDisconnect
import json
import threading
import os
import uuid
import asyncio
import logging
import redis
from dotenv import load_dotenv
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
# Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
# Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 添加密码
)
# Kafka生产者
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 存储WebSocket连接的字典
active_connections = {}
@app.websocket("/asr/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
await websocket.accept()
active_connections[client_id] = websocket
try:
while True:
try:
# 设置接收超时
data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
if data == "ping":
await websocket.send_text("pong")
else:
await websocket.send_text(f"收到消息: {data}")
except asyncio.TimeoutError:
try:
# 发送心跳
await websocket.send_text("heartbeat")
except WebSocketDisconnect:
logger.info(f"客户端 {client_id} 断开连接")
break
except WebSocketDisconnect:
logger.info(f"客户端 {client_id} 断开连接")
except Exception as e:
logger.error(f"WebSocket错误: {e}")
finally:
if client_id in active_connections:
del active_connections[client_id]
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
if not audio:
raise HTTPException(status_code=400, detail="未提供音频文件")
client_id = str(uuid.uuid4())
# 生成缓存键
cache_key = f"asr:{audio.filename}:{client_id}"
# 检查缓存
cached_result = redis_client.get(cache_key)
if cached_result:
logger.info(f"缓存命中: {cache_key}")
return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
content = await audio.read()
temp_audio.write(content)
temp_audio.flush()
task = {
'file_path': temp_audio.name,
'client_id': client_id,
'cache_key': cache_key
}
producer.send(KAFKA_TOPIC, value=task)
producer.flush()
logger.info(f"发送任务到Kafka: {task}")
return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
async def send_transcription(client_id: str, transcription: str):
if client_id in active_connections:
websocket = active_connections[client_id]
await websocket.send_json({"transcription": transcription})
else:
logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
def kafka_consumer(consumer_id):
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
group_id='asr_group',
max_poll_interval_ms=300000
)
for message in consumer:
try:
task = message.value
file_path = task.get('file_path')
client_id = task.get('client_id')
cache_key = task.get('cache_key')
if not file_path or not client_id or not cache_key:
logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
consumer.commit()
continue
result = model.transcribe(file_path)
logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
logger.info(f"转录结果: {result['text']}")
# 将结果存入Redis缓存
redis_client.setex(cache_key, 3600, result['text']) # 缓存1小时
asyncio.run(send_transcription(client_id, result['text']))
os.remove(file_path)
consumer.commit()
except Exception as e:
logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
def start_consumers(num_consumers=1):
for i in range(num_consumers):
consumer_thread = threading.Thread(target=kafka_consumer, args=(i,))
consumer_thread.start()
if __name__ == '__main__':
start_consumers()
uvicorn.run(app, host="0.0.0.0", port=6000)
+122
View File
@@ -0,0 +1,122 @@
from kafka import KafkaConsumer
import json
import asyncio
import redis
import os
from dotenv import load_dotenv
import requests
from concurrent.futures import ThreadPoolExecutor
# 加载 .env 文件
load_dotenv()
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 1 # 消费者数量
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_CHAT_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
# 创建Kafka消费者
def create_kafka_consumer():
return KafkaConsumer(
KAFKA_CHAT_TOPIC,
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='latest',
enable_auto_commit=True,
group_id=KAFKA_CONSUMER_GROUP,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
async def process_chat_request(chat_request):
try:
task_id = chat_request['task_id']
session_id = chat_request['session_id']
query = chat_request['query']
model = chat_request.get('model', 'qwen2.5:3b')
# 设置任务状态为 "processing"
redis_task_client.set(f"chat:{task_id}:status", "processing")
# 从Redis获取历史记录 (使用 session_id)
history = json.loads(redis_client.get(f"chat:{session_id}") or '[]')
# 构建包含历史对话的完整提示
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
for past_query, past_response in history:
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
full_prompt += f"用户: {query}\n助手:"
data = {
"model": model,
"prompt": full_prompt,
"stream": True,
"temperature": 0
}
response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
response.raise_for_status()
text_output = ""
for line in response.iter_lines():
if line:
json_data = json.loads(line)
if 'response' in json_data:
text_output += json_data['response']
# 更新历史记录 (使用 session_id)
history.append((query, text_output))
redis_client.set(f"chat:{session_id}", json.dumps(history))
# 设置任务状态为 "completed" 并存储响应 (使用 task_id)
redis_task_client.set(f"chat:{task_id}:status", "completed")
redis_task_client.set(f"chat:{task_id}:response", text_output)
# 存储当前任务的结果到 REDIS_TASK_DB (db3)
redis_task_client.set(f"chat:{task_id}:result", json.dumps({"query": query, "response": text_output}))
print(f"处理完成 task_id {task_id}, session_id {session_id}: {text_output}")
except Exception as e:
print(f"处理 task {task_id} 时出错: {str(e)}")
# 设置任务状态为 "error"
redis_task_client.set(f"chat:{task_id}:status", "error")
redis_task_client.set(f"chat:{task_id}:error", str(e))
def kafka_consumer_thread(consumer_id):
consumer = create_kafka_consumer()
print(f"消费者 {consumer_id} 已启动")
for message in consumer:
chat_request = message.value
asyncio.run(process_chat_request(chat_request))
def main():
print("启动Kafka消费者处理聊天请求...")
with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
for i in range(KAFKA_CONSUMER_NUM):
executor.submit(kafka_consumer_thread, i)
if __name__ == '__main__':
main()
View File
+63
View File
@@ -0,0 +1,63 @@
import os
from moviepy.editor import VideoFileClip
def mp4_to_wav(input_file, output_file):
"""
将MP4文件转换为WAV格式
:param input_file: 输入的MP4文件路径
:param output_file: 输出的WAV文件路径
"""
try:
# 加载视频文件
video = VideoFileClip(input_file)
# 提取音频
audio = video.audio
# 将音频写入WAV文件
audio.write_audiofile(output_file)
# 关闭视频和音频对象
audio.close()
video.close()
print(f"转换成功: {input_file} -> {output_file}")
except Exception as e:
print(f"转换失败: {input_file} - {str(e)}")
def process_directory(directory):
"""
处理目录中的所有MP4文件
:param directory: 包含MP4文件的目录路径
"""
for filename in os.listdir(directory):
if filename.lower().endswith('.mp4'):
input_file = os.path.join(directory, filename)
output_file = os.path.splitext(input_file)[0] + ".wav"
mp4_to_wav(input_file, output_file)
def main():
# 获取输入路径
input_path = input("请输入MP4文件或包含MP4文件的目录路径: ").strip()
# 检查输入路径是否存在
if not os.path.exists(input_path):
print("错误: 输入路径不存在")
return
# 判断输入路径是文件还是目录
if os.path.isfile(input_path):
if not input_path.lower().endswith('.mp4'):
print("错误: 输入文件不是MP4格式")
return
output_file = os.path.splitext(input_path)[0] + ".wav"
mp4_to_wav(input_path, output_file)
elif os.path.isdir(input_path):
process_directory(input_path)
else:
print("错误: 输入路径既不是文件也不是目录")
if __name__ == "__main__":
main()
+136
View File
@@ -0,0 +1,136 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import httpx
import json
import redis
from typing import List, Dict, Optional
import logging
import ollama
import uuid
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Redis连接
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GenerateRequest(BaseModel):
model: Optional[str] = "qwen2.5:3b"
prompt: str
class RawGenerateRequest(BaseModel):
model: Optional[str] = "qwen2.5:3b"
prompt: str
system_prompt: Optional[str] = None
stream: Optional[bool] = False
raw: Optional[bool] = False
format: Optional[str] = None
options: Optional[Dict] = None
class GenerateResponse(BaseModel):
response: dict
request_id: str
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
logger.info(f"收到请求: {request}")
request_id = str(uuid.uuid4())
try:
response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
full_response = response['message']['content']
request_data = {
"model": request.model,
"prompt": request.prompt,
"response": full_response
}
redis_client.set(f"request:{request_id}", json.dumps(request_data))
response_data = {
"response": full_response,
"model": request.model
}
return GenerateResponse(response=response_data, request_id=request_id)
except Exception as e:
logger.error(f"发生错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/generate")
async def generate_without_history(request: RawGenerateRequest):
"""
处理无历史记录的生成请求。
参数:
- request: RawGenerateRequest对象,包含生成请求的所有参数。
返回:
- 包含生成结果的字典。
"""
try:
response = ollama.generate(
model=request.model,
prompt=request.prompt,
system=request.system_prompt,
format=request.format,
options=request.options,
stream=request.stream
)
response_data = {
"model": request.model,
"response": response['response'],
"done": True,
"context": response.get('context'),
"total_duration": response.get('total_duration'),
"load_duration": response.get('load_duration'),
"prompt_eval_count": response.get('prompt_eval_count'),
"prompt_eval_duration": response.get('prompt_eval_duration'),
"eval_count": response.get('eval_count'),
"eval_duration": response.get('eval_duration')
}
request_id = str(uuid.uuid4())
redis_client.set(f"request:{request_id}", json.dumps(response_data))
return response_data
except Exception as e:
logger.error(f"发生未预期的错误: {e}")
logger.exception("详细错误信息:")
raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
@app.get("/request/{request_id}", response_model=Dict)
async def get_request(request_id: str):
request_data = redis_client.get(f"request:{request_id}")
if request_data:
return json.loads(request_data)
raise HTTPException(status_code=404, detail="请求未找到")
@app.get("/models")
async def list_models():
return ollama.list()
@app.get("/models/{model_name}")
async def show_model(model_name: str):
return ollama.show(model_name)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7000)
+406
View File
@@ -0,0 +1,406 @@
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from fastapi.responses import FileResponse
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import tempfile
import hashlib
from pydantic import BaseModel, Field
# 在文件顶部添加这个函数
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 加载 .env 文件
load_dotenv()
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# Redis 配置
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
OUTPUT_PATH= os.getenv('OUTPUT_PATH')
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
# 创建一个音色到对应 Redis 客户端的映射
voice_to_redis = {
"default": redis_tts_girl,
"girl": redis_tts_girl,
"woman": redis_tts_woman,
"man": redis_tts_man,
"leijun": redis_tts_leijun,
"dufu": redis_tts_dufu,
"hejiong": redis_tts_hejiong,
"mahuateng": redis_tts_mahuateng,
"lidan": redis_tts_lidan,
"yuhua": redis_tts_yuhua,
"liuzhenyun": redis_tts_liuzhenyun,
"dabing": redis_tts_dabing,
"luoxiang": redis_tts_luoxiang,
"xuzhiyuan": redis_tts_xuzhiyuan
}
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, model_name)
# 发送任务到Kafka
producer.send(kafka_topic, task_data)
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
return {
"message": f"{model_name.upper()}请求已排队等待处理",
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{model_name}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
class TTSRequest(BaseModel):
text: str
voice: str = Field(..., description="选择的音色")
class ChatRequest(BaseModel):
session_id: str
query: str
model: str = "qwen2.5:3b"
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
text_hash = get_audio_hash(request.text)
# 验证音色选择
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
if request.voice not in valid_voices:
raise HTTPException(status_code=400, detail="无效的音色选择")
# 如果声音是 'default',则将其视为 'girl'
voice = 'girl' if request.voice == 'default' else request.voice
# 使用对应音色的 Redis 客户端
redis_tts = voice_to_redis[request.voice]
# 检查是否已存在相同内容的音频文件
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
if existing_audio_info:
existing_audio_path = json.loads(existing_audio_info)['path']
if os.path.exists(existing_audio_path):
return {
"message": "TTS请求已完成",
"task_id": task_id,
"status": "completed",
"audio_path": existing_audio_path
}
# 如果不存在,创建新的任务
task_data = {
"task_id": task_id,
"text": request.text,
"text_hash": text_hash,
"voice": request.voice,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 存储任务信息到Redis
redis_task_client.set(f"task_status:tts:{task_id}", "queued")
result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
UPLOAD_DIR = "/obscura/task/audio_upload"
os.makedirs(UPLOAD_DIR, exist_ok=True)
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
with open(file_path, "wb") as temp_audio:
content = await audio.read()
temp_audio.write(content)
task_data = {
'file_path': file_path,
'task_id': task_id,
'status': 'queued'
}
# 存储任务状态,使用一致的键名格式
redis_task_client.set(f"task_status:asr:{task_id}", "queued")
result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"session_id": request.session_id,
"query": request.query,
"model": request.model,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 设置任务状态为 "queued"
redis_task_client.set(f"chat:{task_id}:status", "queued")
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态
task_status = redis_task_client.get(f"chat:{task_id}:status")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis任务数据库获取聊天结果
chat_result = redis_task_client.get(f"chat:{task_id}:result")
if chat_result:
result = json.loads(chat_result)
return {
"status": "completed",
"result": result
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
if task_info:
task_data = json.loads(task_info)
text_hash = task_data['text_hash']
voice = task_data['voice']
# 'default' 和 'girl' 都使用 girl 的 Redis
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
audio_info = redis_tts.get(f"tts:{text_hash}")
if audio_info:
audio_path = json.loads(audio_info)['path']
return {
"status": "completed",
"audio_path": audio_path
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态,使用一致的键名格式
task_status = redis_task_client.get(f"task_status:asr:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis ASR结果数据库获取转录结果
transcription = redis_asr_client.get(f"asr:{task_id}")
if transcription:
return {
"status": "completed",
"transcription": transcription.decode('utf-8')
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从任务信息中获取使用的音色
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
if task_info:
task_data = json.loads(task_info)
voice = task_data.get('voice', 'girl') # 默认使用 'girl'
# 'default' 和 'girl' 都使用 girl 的 Redis
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
# 从对应音色的 Redis 获取音频文件路径
audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
if audio_info:
audio_path = json.loads(audio_info)['path']
if os.path.exists(audio_path):
file_name = os.path.basename(audio_path)
return FileResponse(audio_path, media_type="audio/wav", filename=file_name)
else:
raise HTTPException(status_code=404, detail="音频文件不存在")
elif status == "queued" or status == "processing":
raise HTTPException(status_code=202, detail="音频文件正在生成中")
else:
raise HTTPException(status_code=500, detail="任务处理失败")
raise HTTPException(status_code=404, detail="任务不存在")
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
return {"available_voices": valid_voices}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8008)
View File
+5
View File
@@ -0,0 +1,5 @@
那一年济南冬天零下十几度
我暖气费交不上
因为从那年往前推好几年
我不接任何的商业了早就
因为不愿意再唱唐会了嘛
Binary file not shown.
+4
View File
@@ -0,0 +1,4 @@
金无足赤人无完人嘛
连朕也出过错误
就说这黄羽全图吧
朕每次见着他
Binary file not shown.
+3
View File
@@ -0,0 +1,3 @@
一些研究表明一個晚上良好的睡眠
就能幫助大腦恢復到最佳狀態
所以如果你已經一週
Binary file not shown.
+4
View File
@@ -0,0 +1,4 @@
很多年前我是主持人
做音乐节目
然后当时我们节目敲了当年最红的一个歌手
叫陈冠希
Binary file not shown.
+3
View File
@@ -0,0 +1,3 @@
我们在短短的半年之间的时间里面
就组成了超过一千人的团队
在过去三年多的时间里面
Binary file not shown.
+4
View File
@@ -0,0 +1,4 @@
他去那个商场
两口子去逛 买电视
大家知道现在的智能电视那个遥控器都有一个语音搜索功能
年轻人不怎么用其实
Binary file not shown.
+5
View File
@@ -0,0 +1,5 @@
在我们村里
最有见识的人呢
是我舅
他是个赶马车的
他不但去过县城
Binary file not shown.
+3
View File
@@ -0,0 +1,3 @@
所以无论是中外
最大限度的反抗标准
在很长一段时间都是一种主流立场
Binary file not shown.
+2
View File
@@ -0,0 +1,2 @@
那么今天呢 政治主持人介绍是我们第四次的互联网家峰会
那么这次的规模是世界以来规模最大的
Binary file not shown.
+2
View File
@@ -0,0 +1,2 @@
今年以來 我國全力推動鄉村產業全鏈條升級
鄉村產業振興呈現良好勢頭
Binary file not shown.
+3
View File
@@ -0,0 +1,3 @@
政法機關堅持黨對政法工作的絕對領導
推動政法體制和工作機制
實現歷史性變革
Binary file not shown.
+5
View File
@@ -0,0 +1,5 @@
我们约了他去做他的采访
他已经答应了
然后结果去那天
他那时候已经得了白血病了
他说真是不巧不好意思
Binary file not shown.
+5
View File
@@ -0,0 +1,5 @@
很小很小的房间
来人的话呢
如果一个人说要外出去上厕所
因为厕所是公共厕所
所有人都得起来走到外面去
Binary file not shown.
View File
+315
View File
@@ -0,0 +1,315 @@
import os
import soundfile as sf
import redis
import hashlib
import json
import traceback
from kafka import KafkaConsumer
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import torch
# 加载 .env 文件
load_dotenv()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# Redis 配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Kafka 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# TTS 配置
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# Redis 配置
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
# 初始化 Redis 客户端
redis_tts_girl = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
redis_tts_woman = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
redis_tts_man = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
redis_tts_leijun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
redis_tts_dufu = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
redis_tts_hejiong = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
redis_tts_mahuateng = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
redis_tts_lidan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
redis_tts_yuhua = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
redis_tts_liuzhenyun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
redis_tts_dabing = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
redis_tts_luoxiang = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
redis_task_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_TASK_DB, password=REDIS_PASSWORD)
# 创建音色到对应 Redis 客户端的映射
voice_to_redis = {
"default": redis_tts_girl,
"girl": redis_tts_girl,
"woman": redis_tts_woman,
"man": redis_tts_man,
"leijun": redis_tts_leijun,
"dufu": redis_tts_dufu,
"hejiong": redis_tts_hejiong,
"mahuateng": redis_tts_mahuateng,
"lidan": redis_tts_lidan,
"yuhua": redis_tts_yuhua,
"liuzhenyun": redis_tts_liuzhenyun,
"dabing": redis_tts_dabing,
"luoxiang": redis_tts_luoxiang,
"xuzhiyuan": redis_tts_xuzhiyuan
}
i18n = I18nAuto()
# Voice configurations
GIRL_REF_AUDIO = os.getenv('GIRL_REF_AUDIO')
GIRL_REF_TEXT = os.getenv('GIRL_REF_TEXT')
WOMAN_REF_AUDIO = os.getenv('WOMAN_REF_AUDIO')
WOMAN_REF_TEXT = os.getenv('WOMAN_REF_TEXT')
MAN_REF_AUDIO = os.getenv('MAN_REF_AUDIO')
MAN_REF_TEXT = os.getenv('MAN_REF_TEXT')
LEIJUN_REF_AUDIO = os.getenv('LEIJUN_REF_AUDIO')
LEIJUN_REF_TEXT = os.getenv('LEIJUN_REF_TEXT')
DUFU_REF_AUDIO = os.getenv('DUFU_REF_AUDIO')
DUFU_REF_TEXT = os.getenv('DUFU_REF_TEXT')
HEJIONG_REF_AUDIO = os.getenv('HEJIONG_REF_AUDIO')
HEJIONG_REF_TEXT = os.getenv('HEJIONG_REF_TEXT')
MAHUATENG_REF_AUDIO = os.getenv('MAHUATENG_REF_AUDIO')
MAHUATENG_REF_TEXT = os.getenv('MAHUATENG_REF_TEXT')
LIDAN_REF_AUDIO = os.getenv('LIDAN_REF_AUDIO')
LIDAN_REF_TEXT = os.getenv('LIDAN_REF_TEXT')
YUHUA_REF_AUDIO = os.getenv('YUHUA_REF_AUDIO')
YUHUA_REF_TEXT = os.getenv('YUHUA_REF_TEXT')
LIUZHENYUN_REF_AUDIO = os.getenv('LIUZHENYUN_REF_AUDIO')
LIUZHENYUN_REF_TEXT = os.getenv('LIUZHENYUN_REF_TEXT')
DABING_REF_AUDIO = os.getenv('DABING_REF_AUDIO')
DABING_REF_TEXT = os.getenv('DABING_REF_TEXT')
LUOXIANG_REF_AUDIO = os.getenv('LUOXIANG_REF_AUDIO')
LUOXIANG_REF_TEXT = os.getenv('LUOXIANG_REF_TEXT')
XUZHIYUAN_REF_AUDIO = os.getenv('XUZHIYUAN_REF_AUDIO')
XUZHIYUAN_REF_TEXT = os.getenv('XUZHIYUAN_REF_TEXT')
VOICE_CONFIGS = {
"girl": {
"ref_audio": GIRL_REF_AUDIO,
"ref_text": GIRL_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"woman": {
"ref_audio": WOMAN_REF_AUDIO,
"ref_text": WOMAN_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"man": {
"ref_audio": MAN_REF_AUDIO,
"ref_text": MAN_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"leijun": {
"ref_audio": LEIJUN_REF_AUDIO,
"ref_text": LEIJUN_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"dufu": {
"ref_audio": DUFU_REF_AUDIO,
"ref_text": DUFU_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"hejiong": {
"ref_audio": HEJIONG_REF_AUDIO,
"ref_text": HEJIONG_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"mahuateng": {
"ref_audio": MAHUATENG_REF_AUDIO,
"ref_text": MAHUATENG_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"lidan": {
"ref_audio": LIDAN_REF_AUDIO,
"ref_text": LIDAN_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"default": {
"ref_audio": GIRL_REF_AUDIO,
"ref_text": GIRL_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"yuhua": {
"ref_audio": YUHUA_REF_AUDIO,
"ref_text": YUHUA_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"liuzhenyun": {
"ref_audio": LIUZHENYUN_REF_AUDIO,
"ref_text": LIUZHENYUN_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"dabing": {
"ref_audio": DABING_REF_AUDIO,
"ref_text": DABING_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"luoxiang": {
"ref_audio": LUOXIANG_REF_AUDIO,
"ref_text": LUOXIANG_REF_TEXT,
"ref_language": REF_LANGUAGE
},
"xuzhiyuan": {
"ref_audio": XUZHIYUAN_REF_AUDIO,
"ref_text": XUZHIYUAN_REF_TEXT,
"ref_language": REF_LANGUAGE
}
}
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 在启动时初始化模型
print("正在初始化模型...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
print("模型初始化成功。")
def read_ref_text(voice_type):
ref_text_path = VOICE_CONFIGS[voice_type]["ref_text"]
ref_text = ""
try:
if os.path.exists(ref_text_path):
with open(ref_text_path, 'r', encoding='utf-8') as file:
ref_text = file.read()
else:
print(f"警告:{voice_type} 的参考文本文件 '{ref_text_path}' 不存在。")
except IOError as e:
print(f"错误:无法读取 {voice_type} 的参考文本文件 '{ref_text_path}'{str(e)}")
return ref_text
def synthesize(target_text, output_wav_path, voice):
voice_config = VOICE_CONFIGS[voice]
ref_audio_path = voice_config["ref_audio"]
with open(voice_config["ref_text"], 'r', encoding='utf-8') as file:
ref_text = file.read()
with torch.cuda.device(device):
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(voice_config["ref_language"]),
text=target_text,
text_language=i18n(TARGET_LANGUAGE),
top_p=1,
temperature=1
)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
return output_wav_path
else:
return None
def kafka_consumer():
consumer = KafkaConsumer(
KAFKA_TTS_TOPIC,
bootstrap_servers=KAFKA_BROKER,
auto_offset_reset='latest',
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
)
print(f"TTS消费者已启动")
for message in consumer:
task_id = None
error_occurred = False # 将这行移到循环的开始
try:
task_id = message.value['task_id']
target_text = message.value['text']
text_hash = message.value['text_hash']
voice = message.value.get('voice', 'default')
if voice == 'default':
voice = 'girl'
if voice not in VOICE_CONFIGS:
print(f"警告:无效的音色类型 '{voice}'。使用默认音色。")
voice = "girl"
# 更新任务状态为 "processing"
redis_task_client.set(f"task_status:tts:{task_id}", "processing")
# 使用对应音色的 Redis 客户端
redis_tts = voice_to_redis[voice]
# 检查是否已存在相同内容的音频文件
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
if existing_audio_info:
existing_audio_path = json.loads(existing_audio_info)['path']
if os.path.exists(existing_audio_path):
# 如果文件已存在,直接使用现有文件
output_path = existing_audio_path
else:
# 如果文件不存在,重新生成
output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}_{voice}.wav")
output_path = synthesize(target_text, output_wav_path, voice)
else:
# 如果不存在,创建新的音频文件
output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}_{voice}.wav")
output_path = synthesize(target_text, output_wav_path, voice)
if output_path:
# 将结果保存在对应音色的 Redis 中
redis_tts.set(f"tts:{text_hash}", json.dumps({"path": output_path}))
print(f"音频合成成功: {output_path}")
# 更新任务状态为 "completed"
redis_task_client.set(f"task_status:tts:{task_id}", "completed")
# 存储任务信息
redis_task_client.set(f"task_info:tts:{task_id}", json.dumps({
"text_hash": text_hash,
"voice": voice
}))
else:
print("音频合成失败")
error_occurred = True
except KeyError as e:
print(f"错误:消息中缺少必要的键: {e}")
error_occurred = True
except Exception as e:
print(f"处理消息时出错: {str(e)}")
print(traceback.format_exc())
error_occurred = True
finally:
if error_occurred:
print("处理消息时发生错误")
if task_id:
redis_task_client.set(f"task_status:tts:{task_id}", "failed")
else:
print("消息处理完成")
if __name__ == "__main__":
torch.cuda.set_device(device)
kafka_consumer()
+68
View File
@@ -0,0 +1,68 @@
import os
import whisper
import argparse
def transcribe_audio(model, audio_path):
"""
使用Whisper模型转录音频文件
:param model: 加载的Whisper模型
:param audio_path: 音频文件路径
:return: 转录的文本
"""
try:
result = model.transcribe(audio_path)
return result["text"]
except Exception as e:
print(f"转录失败 {audio_path}: {str(e)}")
return None
def process_directory(directory, model):
"""
处理目录中的所有WAV文件
:param directory: 包含WAV文件的目录路径
:param model: 加载的Whisper模型
"""
for filename in os.listdir(directory):
if filename.lower().endswith('.wav'):
input_file = os.path.join(directory, filename)
output_file = os.path.splitext(input_file)[0] + ".txt"
print(f"正在处理: {input_file}")
transcription = transcribe_audio(model, input_file)
if transcription:
with open(output_file, 'w', encoding='utf-8') as f:
f.write(transcription)
print(f"转录完成: {output_file}")
else:
print(f"转录失败: {input_file}")
def main():
parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"], help="Whisper模型大小")
args = parser.parse_args()
print(f"正在加载Whisper模型 ({args.model})...")
model = whisper.load_model(args.model)
print("模型加载完成")
if os.path.isfile(args.input_path):
if not args.input_path.lower().endswith('.wav'):
print("错误: 输入文件不是WAV格式")
return
output_file = os.path.splitext(args.input_path)[0] + ".txt"
transcription = transcribe_audio(model, args.input_path)
if transcription:
with open(output_file, 'w', encoding='utf-8') as f:
f.write(transcription)
print(f"转录完成: {output_file}")
elif os.path.isdir(args.input_path):
process_directory(args.input_path, model)
else:
print("错误: 输入路径既不是文件也不是目录")
if __name__ == "__main__":
main()
+1
View File
@@ -0,0 +1 @@
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
View File
+35
View File
@@ -0,0 +1,35 @@
kafka:
bootstrap_servers:
- "222.186.136.78:9092"
value_serializer: "json"
topics:
all_frames:
name: "pose-input"
num_consumers: 3
ten_seconds:
name: "cpm-input"
num_consumers: 1
input_topic:
name: "raw-data"
num_consumers: 3
redis:
host: "222.186.136.78"
port: 6379
db: 0
password: "Obscura@2024"
minio:
endpoint: "api.obscura.work"
access_key: "00v3MtLtIAIkR3hkIuYR"
secret_key: "XfDeVe5bJjPU21NEYc023gzJVUTJzQqxsWHqIKMf"
secure: true
mongodb:
uri: "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name: "minio_mongo"
model:
pose-path: "worker_sys/function/yolov8n-pose.pt"
cpm-path: "worker_sys/OpenBMB/MiniCPM-V-2_6"
+362
View File
@@ -0,0 +1,362 @@
import json
import io
from PIL import Image
import torch
from kafka import KafkaConsumer, KafkaProducer
from transformers import AutoModel, AutoTokenizer
import threading
import re
from datetime import datetime
import time
import base64
import numpy as np
import cv2
import redis
from redis import ConnectionPool
import yaml
# 加载配置文件
with open('worker_sys/function/config.yaml', 'r') as file:
config = yaml.safe_load(file)
class JSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime):
return obj.isoformat()
return json.JSONEncoder.default(self, obj)
class ImageSequenceProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.max_size = 512 # 设置最大尺寸
def extract_time_from_key(self, key_name):
# 从key_name中提取时间信息
time_str = key_name.split('_')[-2] + '_' + key_name.split('_')[-1].split('.')[0]
return datetime.strptime(time_str, "%Y%m%d_%H%M%S")
def process_image_sequence(self, image_data, key_names):
frames = []
image_times = []
for i, (img, key_name) in enumerate(zip(image_data, key_names)):
try:
if isinstance(img, np.ndarray):
# 确保图像是 RGB 格式
if img.shape[2] == 3:
frame = Image.fromarray(img)
else:
print(f"Unexpected number of channels for image {i}: {img.shape[2]}")
continue
else:
print(f"Unexpected data type for image {i}: {type(img)}")
continue
frames.append(frame)
image_times.append(self.extract_time_from_key(key_name))
print(f"Successfully processed frame {i}")
except Exception as e:
print(f"Error processing frame {i}: {str(e)}")
continue
if not frames:
raise ValueError("No valid frames were processed")
# 修改时间范围格式
start_time = min(image_times)
end_time = max(image_times)
time_range = {
'start': start_time.strftime('%Y-%m-%d %H:%M'),
'end': end_time.strftime('%Y-%m-%d %H:%M')
}
# # 计算平均时间间隔(以分钟为单位)
# if len(image_times) > 1:
# sequence_period_seconds = (image_times[-1] - image_times[0]).total_seconds()/ (len(image_times) - 1)
# else:
# sequence_period_seconds = 0
total_duration = (end_time - start_time).total_seconds()
num_images = len(image_times)
if num_images > 1:
sequence_period_seconds = total_duration / (num_images - 1)
else:
sequence_period_seconds = 0
question = "Analyze these 3 images as if they were frames from a video. Describe the scene in detail in Chinese, including the setting, number of people, their actions, and any changes or movements observed across the frames."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
"time_range": time_range,
"sequence_period_seconds": sequence_period_seconds
}
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
# 环境提取
environments = ["办公室", "室内", "室外", "会议室", "办公"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
# 改进的人数提取
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
# 动作和互动提取
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
# 物体和家具提取
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
class ImageSequenceAnalysisSystem:
def __init__(self, model_dir):
self.image_processor = ImageSequenceProcessor(model_dir)
def process_image_sequence(self, image_data, cache_keys):
print(f"Attempting to process sequence of {len(image_data)} images")
start_time = time.time()
try:
print("Processing new image sequence...")
for i, (img, cache_key) in enumerate(zip(image_data, cache_keys)):
print(f"Image {i} type: {type(img)}")
if isinstance(img, np.ndarray):
print(f"Image {i} shape: {img.shape}, dtype: {img.dtype}")
else:
print(f"Unexpected data type for image {i}")
print(f"Image {i} cache_key: {cache_key}")
result = self.image_processor.process_image_sequence(image_data, cache_keys)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed image sequence for time range: {result['time_range']}")
print(f"Average time between frames: {result['sequence_period_seconds']:.2f} minutes")
print(f"Processing time: {processing_time:.2f} seconds")
return result
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing image sequence: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
return None
class ImageProcessingNode:
def __init__(self, bootstrap_servers, input_topic, model_dir, group_id, redis_pool_db0, redis_pool_db1):
self.consumer = KafkaConsumer(
input_topic,
bootstrap_servers=bootstrap_servers,
group_id=group_id,
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
auto_offset_reset='earliest'
)
self.producer = KafkaProducer(
bootstrap_servers=bootstrap_servers,
value_serializer=lambda x: json.dumps(x, cls=JSONEncoder).encode('utf-8')
)
self.input_topic = input_topic
self.analysis_system = ImageSequenceAnalysisSystem(model_dir)
self.group_id = group_id
self.redis_pool_db0 = redis_pool_db0
self.redis_pool_db1 = redis_pool_db1
self.lock = threading.Lock()
self.image_buffer = []
self.image_info_buffer = []
self.cache_key_buffer = [] # 新增:用于存储cache_key
def process_and_produce(self, message_value):
cache_key = message_value.get('cache_key')
etag = message_value.get('etag')
size = message_value.get('size')
object_key = message_value.get('object')
# print(f"Consumer {self.group_id} received image with key_hash: {key_hash}")
try:
with redis.Redis(connection_pool=self.redis_pool_db0) as redis_client_db0:
img_str = redis_client_db0.get(cache_key)
if img_str is None:
print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
return
img_data = base64.b64decode(img_str)
nparr = np.frombuffer(img_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
self.image_buffer.append(img)
self.image_info_buffer.append({
'key_name': cache_key, # 使用 cache_key 作为 key_name
'etag': etag,
'size': size,
'object_key': object_key
# 'key_hash': key_hash
})
self.cache_key_buffer.append(cache_key)
if len(self.image_buffer) == 3:
result = self.analysis_system.process_image_sequence(self.image_buffer, self.cache_key_buffer)
if result:
result_data = {
'results': result,
'image_sequence': self.image_info_buffer
}
with self.lock:
with redis.Redis(connection_pool=self.redis_pool_db1) as redis_client_db1:
# 使用第一张图片的 cache_key 作为 sequence_key
sequence_key = f"{self.image_info_buffer[0]['key_name']}"
redis_client_db1.set(sequence_key, json.dumps(result_data))
print(f"Consumer {self.group_id} processed and saved results to Redis db1 with sequence_key {sequence_key}")
print(f"Processed image sequence: {[info['key_name'] for info in self.image_info_buffer]}")
# 清空缓冲区
self.image_buffer = []
self.image_info_buffer = []
self.cache_key_buffer = []
except Exception as e:
print(f"Error processing image: {str(e)}")
import traceback
traceback.print_exc()
def run(self):
print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
while True:
messages = self.consumer.poll(timeout_ms=1000)
for tp, records in messages.items():
for record in records:
message_value = record.value
self.process_and_produce(message_value)
# 每分钟打印一次当前缓冲区状态
if len(self.image_buffer) < 3:
print(f"Still waiting for more images. Current buffer size: {len(self.image_buffer)}")
time.sleep(60) # 等待60秒(1分钟)
def start_consumer(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password):
redis_pool_db0 = ConnectionPool(host=redis_host, port=redis_port, db=0, password=redis_password)
redis_pool_db1 = ConnectionPool(host=redis_host, port=redis_port, db=2, password=redis_password)
node = ImageProcessingNode(
kafka_bootstrap_servers,
input_topic,
model_dir,
group_id,
redis_pool_db0,
redis_pool_db1
)
print(f"Image Processing Node {group_id} initialized.")
print(f"Listening on topic: {input_topic}")
node.run()
if __name__ == "__main__":
# Configuration
# kafka_bootstrap_servers = ['222.186.136.78:9092']
# input_topic = 'cpm-input'
# model_dir = 'worker_sys/OpenBMB/MiniCPM-V-2_6'
# redis_host = '222.186.136.78'
# redis_port = 6379
# redis_password = 'Obscura@2024'
# num_consumers = 1
kafka_bootstrap_servers = config['kafka']['bootstrap_servers']
input_topic = config['kafka']['topics']['ten_seconds']['name']
model_dir = config['model']['cpm-path']
redis_host = config['redis']['host']
redis_port = config['redis']['port']
redis_password = config['redis']['password']
num_consumers = config['kafka']['topics']['ten_seconds']['num_consumers']
# 创建多个消费者线程
consumer_threads = []
for i in range(num_consumers):
group_id = f'cpm_group_{i}'
thread = threading.Thread(
target=start_consumer,
args=(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password)
)
consumer_threads.append(thread)
thread.start()
# 等待所有消费者线程完成
for thread in consumer_threads:
thread.join()
+183
View File
@@ -0,0 +1,183 @@
import json
import cv2
import numpy as np
from kafka import KafkaConsumer
from ultralytics import YOLO
import threading
import redis
import base64
import yaml
from PIL import Image
# 加载配置文件
with open('worker_sys/function/config.yaml', 'r') as file:
config = yaml.safe_load(file)
class YOLOv8nPoseProcessor:
def __init__(self, model_path):
self.model = YOLO(model_path)
def process_image(self, img):
results = self.model(img)
return results
def format_results(self, results):
result_data = []
for result in results:
boxes = result.boxes.xywh.tolist()
keypoints = result.keypoints.xy.tolist() if hasattr(result, 'keypoints') else None
classes = result.boxes.cls.tolist()
confs = result.boxes.conf.tolist()
for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
result_data.append({
'box': box,
'keypoints': keypoints[i] if keypoints else None,
'class': int(cls),
'class_name': self.model.names[int(cls)],
'confidence': float(conf)
})
if not result_data:
result_data.append({
'box': None,
'keypoints': None,
'class': None,
'class_name': None,
'confidence': None,
'message': 'No object detected'
})
return json.dumps(result_data)
class ImageProcessingNode:
def __init__(self, bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
self.consumer = KafkaConsumer(
input_topic,
bootstrap_servers=bootstrap_servers,
group_id=group_id,
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
auto_offset_reset='earliest'
)
self.input_topic = input_topic
self.yolo_processor = YOLOv8nPoseProcessor(model_path)
self.group_id = group_id
self.redis_client_db0 = redis.Redis(host=redis_host, port=redis_port, db=0, password=redis_password)
self.redis_client_db1 = redis.Redis(host=redis_host, port=redis_port, db=1, password=redis_password)
def process_and_produce(self, message):
cache_key = message.get('cache_key')
object_key = message.get('object')
etag = message.get('etag')
size = message.get('size')
print(f"Consumer {self.group_id} processing image with cache_key: {cache_key}")
# 从Redis db0获取图片数据
img_str = self.redis_client_db0.get(cache_key)
if img_str is None:
print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
return
# 将base64编码的图片数据转换为OpenCV格式
try:
img_data = base64.b64decode(img_str)
nparr = np.frombuffer(img_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is None:
print(f"Error: Unable to decode image data for cache_key: {cache_key}")
return
results = self.yolo_processor.process_image(img)
formatted_results = self.yolo_processor.format_results(results)
# 创建包含所有必要信息的字典
result_data = {
"pose_results": json.loads(formatted_results),
'cache_key': cache_key,
"object": object_key,
"etag": etag,
"size": size
}
# 将结果保存到Redis db1
try:
serialized_data = json.dumps(result_data)
if not serialized_data:
print(f"Error: Serialized data is empty for cache_key: {cache_key}")
return
self.redis_client_db1.set(cache_key, serialized_data)
print(f"Consumer {self.group_id} processed and saved results to Redis db1 with cache_key {cache_key}")
except TypeError as e:
print(f"Error serializing data: {e}")
print(f"Problematic data: {result_data}")
except Exception as e:
print(f"Unexpected error when saving to Redis: {e}")
except Exception as e:
print(f"Error processing image: {e}")
def run(self):
print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
for message in self.consumer:
message_value = message.value
cache_key = message_value.get('cache_key')
if cache_key:
threading.Thread(target=self.process_and_produce, args=(message_value,)).start()
else:
print(f"Consumer {self.group_id} error: Received message without cache_key")
def start_consumer(kafka_bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
node = ImageProcessingNode(
kafka_bootstrap_servers,
input_topic,
model_path,
group_id,
redis_host,
redis_port,
redis_password
)
print(f"Image Processing Node {group_id} initialized.")
print(f"Listening on topic: {input_topic}")
node.run()
if __name__ == "__main__":
# 创建多个消费者线程
# Kafka 配置
kafka_bootstrap_servers = config['kafka']['bootstrap_servers']
input_topic = config['kafka']['topics']['all_frames']['name']
num_consumers = config['kafka']['topics']['all_frames']['num_consumers']
# 模型路径
model_path = config['model']['pose-path']
# Redis 配置
redis_host = config['redis']['host']
redis_port = config['redis']['port']
redis_password = config['redis']['password']
consumer_threads = []
for i in range(num_consumers):
group_id = f'pose_group_{i}'
thread = threading.Thread(
target=start_consumer,
args=(
kafka_bootstrap_servers,
input_topic,
model_path,
group_id,
redis_host,
redis_port,
redis_password
)
)
consumer_threads.append(thread)
thread.start()
# 等待所有消费者线程完成
for thread in consumer_threads:
thread.join()
+204
View File
@@ -0,0 +1,204 @@
import json
import yaml
from kafka import KafkaConsumer, KafkaProducer
import time
from datetime import datetime
import redis
from minio import Minio
import io
from PIL import Image
import base64
import traceback
# 加载配置文件
with open('worker_sys/function/config.yaml', 'r') as file:
config = yaml.safe_load(file)
# Kafka 配置
kafka_config = {
'bootstrap_servers': config['kafka']['bootstrap_servers'],
'value_serializer': lambda v: json.dumps(v).encode('utf-8') if config['kafka']['value_serializer'] == 'json' else None
}
# 创建Kafka生产者
producer = KafkaProducer(**kafka_config)
# 创建Kafka消费者(用于输入主题)
consumer = KafkaConsumer(
config['kafka']['input_topic']['name'],
group_id='image-processor',
bootstrap_servers=config['kafka']['bootstrap_servers']
)
# Redis 配置
redis_config = config['redis']
# 创建 Redis 客户端
redis_client = redis.Redis(**redis_config)
# MinIO 配置
minio_config = config['minio']
# 创建 MinIO 客户端
minio_client = Minio(
minio_config['endpoint'],
access_key=minio_config['access_key'],
secret_key=minio_config['secret_key'],
secure=minio_config['secure']
)
# Kafka topics
topic_all_frames = config['kafka']['topics']['all_frames']['name']
topic_ten_seconds = config['kafka']['topics']['ten_seconds']['name']
topic_input = config['kafka']['input_topic']['name']
# 消费者数量
NUM_CONSUMERS_ALL_FRAMES = config['kafka']['topics']['all_frames']['num_consumers']
NUM_CONSUMERS_TEN_SECONDS = config['kafka']['topics']['ten_seconds']['num_consumers']
NUM_CONSUMERS_INPUT = config['kafka']['input_topic']['num_consumers']
def parse_key_name(key_name):
parts = key_name.split('/')
bucket = parts[0]
image_name = '/'.join(parts[1:])
image_parts = image_name.rsplit('.', 1)[0].split('_')
camera_name = image_parts[0]
timestamp = '_'.join(image_parts[1:3])
return bucket, camera_name, timestamp
def should_send_to_ten_seconds_topic(timestamp):
try:
dt = datetime.strptime(timestamp, '%Y%m%d_%H%M%S')
return dt.second % 10 == 0
except ValueError as e:
print(f"Error parsing timestamp '{timestamp}': {str(e)}")
return False
def get_image_from_minio_and_cache(key_name):
bucket, object_name = key_name.split('/', 1)
cache_key = f"{key_name}"
try:
response = minio_client.get_object(bucket, object_name)
image_data = response.read()
print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}")
image = Image.open(io.BytesIO(image_data))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
redis_client.setex(cache_key, 86400, img_str)
print(f"Successfully cached image in Redis: {cache_key}")
return cache_key
except Exception as e:
print(f"Error in get_image_from_minio_and_cache: {str(e)}")
print("Traceback:")
traceback.print_exc()
return None
def process_message(message):
value_data = json.loads(message.value.decode('utf-8'))
key_name = value_data['Key']
print(f"Processing key: {key_name}")
# 从 key_name 中提取 bucket
parts = key_name.split('/', 1)
if len(parts) != 2:
print(f"Error: Invalid key format. Expected 'bucket/object', got '{key_name}'")
return
bucket, object_name = parts
# 解析 camera_name 和 timestamp
camera_name, timestamp = parse_object_name(object_name)
cache_key = get_image_from_minio_and_cache(key_name)
if cache_key is None:
print(f"Failed to process image: {key_name}")
return
# 提取 etag 和 size 信息
object_info = value_data['Records'][0]['s3']['object']
etag = object_info.get('eTag', '')
size = object_info.get('size', 0)
message_data = {
'bucket': bucket,
'camera_name': camera_name,
'timestamp': timestamp,
'object': object_name,
'cache_key': cache_key,
'etag': etag,
'size': size
}
# 发送到 pose-input 主题
producer.send(topic_all_frames, value=message_data)
print(f"Sent message to {topic_all_frames}: {key_name}")
# 发送到 cpm-input 主题
producer.send(topic_ten_seconds, value=message_data)
print(f"Sent message to {topic_ten_seconds}: {key_name}")
# #只有在满足特定条件时才发送到 cpm-input 主题
# if should_send_to_ten_seconds_topic(timestamp):
# producer.send(TOPIC_TEN_SECONDS, value=message_data)
# print(f"Sent message to {TOPIC_TEN_SECONDS}: {key_name}")
producer.flush()
print(f"Successfully processed and sent messages for: {key_name}")
def parse_object_name(object_name):
# 假设对象名格式为 "cameraX_YYYYMMDD_HHMMSS.jpg"
parts = object_name.split('_')
if len(parts) != 3:
raise ValueError(f"Invalid object name format: {object_name}")
camera_name = parts[0]
timestamp = f"{parts[1]}_{parts[2].split('.')[0]}"
return camera_name, timestamp
def get_image_from_minio_and_cache(key_name):
print(f"Received key_name: {key_name}")
try:
bucket, object_name = key_name.split('/', 1)
except ValueError as e:
print(f"Error splitting key_name '{key_name}': {str(e)}")
return None
cache_key = f"{key_name}"
try:
response = minio_client.get_object(bucket, object_name)
image_data = response.read()
print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}")
image = Image.open(io.BytesIO(image_data))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
redis_client.setex(cache_key, 86400, img_str)
print(f"Successfully cached image in Redis: {cache_key}")
return cache_key
except Exception as e:
print(f"Error in get_image_from_minio_and_cache: {str(e)}")
print("Traceback:")
traceback.print_exc()
return None
def main():
print("Starting image processing service...")
for message in consumer:
process_message(message)
if __name__ == "__main__":
main()
+91
View File
@@ -0,0 +1,91 @@
import redis
import pymongo
import time
import json
import yaml
def sync_data(redis_client, mongo_db, redis_db, mongo_collection_name):
mongo_collection = mongo_db[mongo_collection_name]
all_keys = redis_client.keys('*')
synced_count = 0
for key in all_keys:
try:
poem_data = redis_client.get(key)
if poem_data:
poem_dict = json.loads(poem_data)
# 检查MongoDB中是否已存在相同的文档
existing_doc = mongo_collection.find_one({'_id': key.decode('utf-8')})
if not existing_doc or existing_doc != poem_dict:
mongo_collection.update_one(
{'_id': key.decode('utf-8')},
{'$set': poem_dict},
upsert=True
)
print(f"Synced data with key: {key} from Redis DB {redis_db} to MongoDB collection '{mongo_collection_name}'")
synced_count += 1
except json.JSONDecodeError:
print(f"Error decoding JSON for key: {key} in Redis DB {redis_db}")
except Exception as e:
print(f"Error syncing data for key {key} in Redis DB {redis_db}: {str(e)}")
return synced_count
def main(config):
# Redis配置
redis_host = config['redis']['host']
redis_port = config['redis']['port']
redis_password = config['redis']['password']
# MongoDB配置
mongo_uri = config['mongodb']['uri']
mongo_db_name = config['mongodb']['db_name']
# 连接到MongoDB
mongo_client = pymongo.MongoClient(mongo_uri)
mongo_db = mongo_client[mongo_db_name]
# 固定的Redis数据库和MongoDB集合映射
db_collection_map = {
0: 'pose-result-db0',
1: 'pose-result-db1',
2: 'cpm-result-db2'
}
print("Selected databases and collections for syncing:")
for db, collection in db_collection_map.items():
print(f" Redis DB {db} -> MongoDB collection '{collection}'")
while True:
print("Starting sync...")
total_synced = 0
for db, collection in db_collection_map.items():
print(f"Syncing Redis DB {db} to MongoDB collection '{collection}'...")
try:
redis_client = redis.Redis(host=redis_host, port=redis_port, db=db, password=redis_password)
synced_count = sync_data(redis_client, mongo_db, db, collection)
total_synced += synced_count
except redis.exceptions.AuthenticationError:
print(f"Error: Authentication failed for Redis DB {db}. Skipping...")
except redis.exceptions.ConnectionError:
print(f"Error: Unable to connect to Redis DB {db}. Skipping...")
except Exception as e:
print(f"Error occurred while syncing Redis DB {db}: {str(e)}. Skipping...")
if total_synced > 0:
print(f"Sync completed. {total_synced} documents synced. Waiting for next update...")
else:
print("No new data to sync. Waiting for next update...")
time.sleep(300) # 等待5分钟后再次同步
if __name__ == "__main__":
# 加载配置文件
with open('worker_sys/function/config.yaml', 'r') as file:
config = yaml.safe_load(file)
# 运行主程序
main(config)
+299
View File
@@ -0,0 +1,299 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import json
import re
from pymongo import MongoClient
import time
from bson import ObjectId
import os
import glob
from datetime import datetime, timedelta
# 数据库连接模块
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.results_collection = self.db[results_collection_name]
def save_result(self, result):
# 如果 result 中没有 filename,使用时间戳作为替代
# filename = result.get('filename', f"unknown_{result['timestamp']}")
filename = result.get('filename')
# 检查是否已存在相同 filename 的结果
existing_result = self.results_collection.find_one({'filename': filename})
if existing_result:
print(f"Video with filename {filename} has already been processed. Skipping.")
return
# 确保 result 中有 filename
result['filename'] = filename
# 将 ObjectId 转换为字符串
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
result['video_id'] = str(result['video_id'])
self.results_collection.insert_one(result)
def is_sequence_processed(self, filename):
return self.results_collection.find_one({'filename': filename}) is not None
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
# 视频处理模块
class ImageSequenceProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.max_size = 512 # 设置最大尺寸
def compress_image(self, image):
# 保持纵横比的情况下调整图片大小
image.thumbnail((self.max_size, self.max_size))
# 如果图像已经是JPEG格式,直接返回调整大小后的图像
if image.format == 'JPEG':
return image
# 对于非JPEG格式,进行压缩
buffer = io.BytesIO()
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 保持透明度
image.save(buffer, format="PNG", optimize=True)
else:
# 转换为JPEG并压缩
image.convert('RGB').save(buffer, format="JPEG", quality=85, optimize=True)
buffer.seek(0)
return Image.open(buffer)
def process_image_sequence(self, image_paths):
frames = [self.compress_image(Image.open(img_path)) for img_path in image_paths]
question = "Analyze these 10 images as if they were frames from a video. Describe the scene in detail in Chinese, including the setting, number of people, their actions, and any changes or movements observed across the frames."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
# 环境提取
environments = ["办公室", "室内", "室外", "会议室", "办公"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
# 改进的人数提取
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
# 动作和互动提取
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
# 物体和家具提取
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 主处理类
class ImageSequenceAnalysisSystem:
def __init__(self, mongo_uri, db_name, model_dir, results_collection_name):
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.image_processor = ImageSequenceProcessor(model_dir)
self.last_processed_time = datetime.now() - timedelta(hours=1)
def get_all_images(self, image_folders):
image_files = []
for folder in image_folders:
image_files.extend(glob.glob(os.path.join(folder, '*.jpg')))
image_files.sort()
return image_files
def process_image_sequence(self, image_paths):
print(f"Attempting to process sequence: {[os.path.basename(img) for img in image_paths]}")
start_time = time.time()
try:
# 使用第一张图片的文件名作为序列的标识符
filename = os.path.basename(image_paths[0])
if self.db_handler.is_sequence_processed(filename):
print(f"Skipping already processed image sequence: {filename}")
return False
print("Processing new image sequence...")
result = self.image_processor.process_image_sequence(image_paths)
# timestamp = datetime.now()
# result['timestamp'] = timestamp.strftime("%Y%m%d_%H%M%S")
result['image_paths'] = image_paths
result['filename'] = filename
# 计算图片序列的周期
image_times = [self.get_file_time(img) for img in image_paths]
if len(image_times) >= 2:
time_diff = (image_times[-1] - image_times[0]).total_seconds()
period_minutes = time_diff / (len(image_times) - 1) / 60
result['sequence_period_minutes'] = round(period_minutes, 2)
# 添加时间段信息
result['time_range'] = {
'start': image_times[0].strftime("%Y-%m-%d %H:%M"),
'end': image_times[-1].strftime("%Y-%m-%d %H:%M")
}
save_result = self.db_handler.save_result(result)
print(f"Result saved to: {self.db_handler.results_collection.name}")
print(f"Result filename: {filename}")
return save_result
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing image sequence: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
return False
# @staticmethod
# def extract_time_from_filename(filename):
# # 假设文件名格式为 "YYYYMMDDHHMMSS.jpg"
# time_str = filename.split('.')[0]
# return datetime.strptime(time_str, "%Y%m%d%H%M")
@staticmethod
def get_file_time(file_path):
# 获取文件的修改时间
mod_time = os.path.getmtime(file_path)
return datetime.fromtimestamp(mod_time)
def process_all_unprocessed_images(self, image_folders):
print(f"Searching for unprocessed images in: {image_folders}")
all_images = self.get_all_images(image_folders)
print(f"Found {len(all_images)} images in total")
selected_images = all_images[::10]
# Group images into sequences of 3
image_sequences = [selected_images[i:i+10] for i in range(0, len(selected_images), 10)]
# image_sequences = [all_images[i:i+10] for i in range(0, len(all_images), 10)]
processed_sequences = 0
for sequence in image_sequences:
if len(sequence) == 10:
self.process_image_sequence(sequence)
processed_sequences += 1
else:
print(f"Warning: Incomplete sequence. Found {len(sequence)} images.")
if processed_sequences == 0:
print("All current photos have been processed. Waiting for new photos...")
else:
print(f"Processed {processed_sequences} sequences.")
def run(self, root_folders):
print(f"Starting the system with root folder: {', '.join(root_folders)}")
while True:
current_time = datetime.now()
time_since_last_process = (current_time - self.last_processed_time).total_seconds()
if time_since_last_process >= 3600: # 1小时 = 3600秒
self.process_all_unprocessed_images(root_folders)
self.last_processed_time = current_time
# 计算下次检查的等待时间
wait_time = max(0, 3600 - (datetime.now() - self.last_processed_time).total_seconds())
print(f"Waiting for new photos... Next check in {wait_time:.0f} seconds.")
time.sleep(60) # 每分钟检查一次是否需要处理
# 使用示例
if __name__ == "__main__":
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "cpm"
model_dir = "worker_sys/OpenBMB/MiniCPM-V-2_6"
root_folders = [
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" ,
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics"
] # 修改为 cam1 文件夹的路径
system = ImageSequenceAnalysisSystem(mongo_uri, db_name, model_dir, results_collection_name)
system.run(root_folders)
+194
View File
@@ -0,0 +1,194 @@
import torch
from PIL import Image
import json
from pymongo import MongoClient
import os
import glob
from datetime import datetime, timedelta
from ultralytics import YOLO
from bson import ObjectId
import time
# 数据库连接模块
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.results_collection = self.db[results_collection_name]
def save_result(self, result):
filename = result.get('filename', f"unknown_{result['timestamp']}")
existing_result = self.results_collection.find_one({'filename': filename})
if existing_result:
print(f"Image with filename {filename} has already been processed. Skipping.")
return False
result['filename'] = filename
if 'image_id' in result and isinstance(result['image_id'], ObjectId):
result['image_id'] = str(result['image_id'])
self.results_collection.insert_one(result)
return True
def is_image_processed(self, filename):
return self.results_collection.find_one({'filename': filename}) is not None
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
# YOLOv8nPoseProcessor 类
class YOLOv8nPoseProcessor:
def __init__(self, model_path):
self.model = YOLO(model_path)
def process_image(self, img):
results = self.model(img)
return results
def format_results(self, results):
result_data = []
for result in results:
boxes = result.boxes.xywh.tolist()
keypoints = result.keypoints.xy.tolist() if hasattr(result, 'keypoints') else None
classes = result.boxes.cls.tolist()
confs = result.boxes.conf.tolist()
for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
result_data.append({
'box': box,
'keypoints': keypoints[i] if keypoints else None,
'class': int(cls),
'class_name': self.model.names[int(cls)],
'confidence': float(conf)
})
if not result_data:
result_data.append({
'box': None,
'keypoints': None,
'class': None,
'class_name': None,
'confidence': None,
'message': 'No object detected'
})
return json.dumps(result_data)
# 主处理类
class ImageAnalysisSystem:
def __init__(self, mongo_uri, db_name, model_path, results_collection_name):
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.image_processor = YOLOv8nPoseProcessor(model_path)
self.last_processed_time = datetime.now() - timedelta(hours=1)
def get_all_images(self, image_folders):
image_files = []
for folder in image_folders:
image_files.extend(glob.glob(os.path.join(folder, '*.jpg')))
image_files.sort()
return image_files
@staticmethod
def get_file_time(file_path):
# 获取文件的修改时间
mod_time = os.path.getmtime(file_path)
return datetime.fromtimestamp(mod_time)
def process_image(self, image_path):
print(f"Attempting to process image: {os.path.basename(image_path)}")
try:
# json_folder = os.path.join("/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2", 'json')
# json_filename = f"{os.path.basename(image_path).split('.')[0]}.json"
# json_path = os.path.join(json_folder, json_filename)
# if os.path.exists(json_path):
# print(f"Skipping already processed image: {json_path}")
# return
filename = os.path.basename(image_path)
if self.db_handler.is_image_processed(filename):
print(f"Skipping already processed image: {filename}")
return False
print("Processing new image...")
image = Image.open(image_path)
results = self.image_processor.process_image(image)
formatted_results = self.image_processor.format_results(results)
# timestamp = datetime.now()
file_timestamp = self.get_file_time(image_path)
result = {
'timestamp': file_timestamp.strftime("%Y%m%d_%H%M%S"),
'image_path': image_path,
'filename': os.path.basename(image_path),
'results': json.loads(formatted_results)
}
# os.makedirs(json_folder, exist_ok=True)
# with open(json_path, 'w', encoding='utf-8') as f:
# json.dump(result, f, ensure_ascii=False, indent=4, cls=JSONEncoder)
# self.db_handler.save_result(result)
# print(f"Processed image at: {timestamp}")
# print(f"JSON saved to: {json_path}")
# print(f"result saved to: {results_collection_name}")
if self.db_handler.save_result(result):
print(f"Result saved to: {self.db_handler.results_collection.name}")
return True
else:
print(f"Image {filename} was already in the database. Skipping.")
return False
except Exception as e:
print(f"Error processing image: {str(e)}")
import traceback
traceback.print_exc()
def process_all_unprocessed_images(self, image_folders):
print(f"Searching for unprocessed images in: {image_folders}")
all_images = self.get_all_images(image_folders)
print(f"Found {len(all_images)} images in total")
processed_count = 0
for image_path in all_images:
if self.process_image(image_path):
processed_count += 1
return processed_count
def run(self, root_folder):
print(f"Starting the system with root folder: {', '.join(root_folder)}")
# image_folder = os.path.join(root_folder)
while True:
print("Checking for unprocessed images...")
processed_count = self.process_all_unprocessed_images(root_folder)
if processed_count > 0:
print(f"Finished processing {processed_count} images.")
else:
print("No new images to process. Waiting for new images...")
# 等待一段时间后再次检查新图片
time.sleep(60) # 每分钟检查一次是否有新图片
# 使用示例
if __name__ == "__main__":
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "pose"
model_path = "worker_sys/function/yolov8x-pose.pt" # 请确保这个路径指向你的YOLO-Pose模型文件
root_folder = [
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" ,
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics"
] # 修改为 cam1 文件夹的路径
system = ImageAnalysisSystem(mongo_uri, db_name, model_path, results_collection_name)
system.run(root_folder)
+255
View File
@@ -0,0 +1,255 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from decord import VideoReader, cpu
import json
import re
from pymongo import MongoClient
import io
from minio import Minio
import time
from bson import ObjectId
import concurrent.futures
import os
# Minio连接模块
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
data = response.read()
# print(f"Read {len(data)} bytes from Minio for {object_name}")
return data
# 数据库连接模块
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.minio_files_collection = self.db['minio_files']
self.results_collection = self.db[results_collection_name]
def get_unprocessed_videos(self):
processed_etags = set(self.results_collection.distinct('etag'))
return self.minio_files_collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'},
'etag': {'$nin': list(processed_etags)}
})
def save_result(self, result):
# 检查是否已存在相同 etag 的结果
existing_result = self.results_collection.find_one({'etag': result['etag']})
if existing_result:
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
return
# 将 ObjectId 转换为字符串
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
result['video_id'] = str(result['video_id'])
self.results_collection.insert_one(result)
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
# 视频处理模块
class VideoProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.MAX_NUM_FRAMES = 64
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
# 环境提取
environments = ["办公室", "室内", "室外", "会议室", "办公"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
# 改进的人数提取
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
# 动作和互动提取
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
# 物体和家具提取
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 主处理类
class VideoAnalysisSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.video_processor = VideoProcessor(model_dir)
def process_video(self, video_doc):
start_time = time.time()
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
result = self.video_processor.process_video(video_data, video_doc['object_name'])
result['etag'] = video_doc['etag']
result['bucket_name'] = video_doc['bucket_name']
result['object_name'] = video_doc['object_name']
self.db_handler.save_result(result)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed video: {video_doc['object_name']}")
print(f"Processing time: {processing_time:.2f} seconds")
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(60)
continue
for video_doc in unprocessed_videos:
self.process_video(video_doc)
print("Finished processing current batch of videos. Waiting for new videos...")
time.sleep(30)
# 使用示例
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "douyin_results"
model_dir = "MiniCPM-V-2_6"
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name)
system.run()
+248
View File
@@ -0,0 +1,248 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from decord import VideoReader, cpu
import json
import re
from pymongo import MongoClient
import io
from minio import Minio
import time
from bson import ObjectId
# Minio连接模块
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
data = response.read()
print(f"Read {len(data)} bytes from Minio for {object_name}")
return data
# 数据库连接模块
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.minio_files_collection = self.db['minio_files']
self.results_collection = self.db[results_collection_name]
def get_unprocessed_videos(self):
# 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频
processed_etags = set(self.results_collection.distinct('etag'))
return self.minio_files_collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': 'videoupload/'},
'etag': {'$nin': list(processed_etags)}
})
def save_result(self, result):
# 检查是否已存在相同 etag 的结果
existing_result = self.results_collection.find_one({'etag': result['etag']})
if existing_result:
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
return
# 将 ObjectId 转换为字符串
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
result['video_id'] = str(result['video_id'])
self.results_collection.insert_one(result)
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
# 视频处理模块
class VideoProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.MAX_NUM_FRAMES = 64
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
# 环境提取
environments = ["办公室", "室内", "室外", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
# 改进的人数提取
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
# 动作和互动提取
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
# 物体和家具提取
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 主处理类
class VideoAnalysisSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.video_processor = VideoProcessor(model_dir)
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(1)
continue
for video_doc in unprocessed_videos:
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
result = self.video_processor.process_video(video_data, video_doc['object_name'])
# 添加额外信息到结果中
result['etag'] = video_doc['etag']
# result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串
result['bucket_name'] = video_doc['bucket_name']
result['object_name'] = video_doc['object_name']
# 保存结果到 MongoDB
self.db_handler.save_result(result)
print(f"Processed video: {video_doc['object_name']}")
# print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder))
except Exception as e:
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
import traceback
traceback.print_exc() # 打印完整的错误堆栈
print("Finished processing current batch of videos. Waiting for new videos...")
time.sleep(30)
# 使用示例
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "douyin_results"
model_dir = "OpenBMB/MiniCPM-V-2_6"
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name)
system.run()
+264
View File
@@ -0,0 +1,264 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from decord import VideoReader, cpu
import json
import re
from datetime import datetime, timedelta
from pymongo import MongoClient
import io
from minio import Minio
import time
import os
from bson import ObjectId
# Minio连接模块
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
return response.read()
# 数据库连接模块
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.minio_files_collection = self.db['minio_files']
self.results_collection = self.db[results_collection_name]
def get_unprocessed_videos(self):
# 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频
processed_etags = set(self.results_collection.distinct('etag'))
return self.minio_files_collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': '/douyin/'},
'etag': {'$nin': list(processed_etags)}
})
def save_result(self, result):
# 检查是否已存在相同 etag 的结果
existing_result = self.results_collection.find_one({'etag': result['etag']})
if existing_result:
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
return
# 将 ObjectId 转换为字符串
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
result['video_id'] = str(result['video_id'])
self.results_collection.insert_one(result)
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
# 视频处理模块
class VideoProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.MAX_NUM_FRAMES = 64
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
start_time, end_time = self.extract_time_from_filename(object_name)
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
"start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
"end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
}
@staticmethod
def extract_time_from_filename(object_name):
# 从 object_name 中提取文件名
filename = os.path.basename(object_name)
# 从文件名中提取日期时间部分
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
# 环境提取
environments = ["办公室", "室内", "室外", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
# 改进的人数提取
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
# 动作和互动提取
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
# 物体和家具提取
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 主处理类
class VideoAnalysisSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.video_processor = VideoProcessor(model_dir)
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(1)
continue
for video_doc in unprocessed_videos:
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
result = self.video_processor.process_video(video_data, video_doc['object_name'])
# 添加额外信息到结果中
result['etag'] = video_doc['etag']
# result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串
result['bucket_name'] = video_doc['bucket_name']
result['object_name'] = video_doc['object_name']
# 保存结果到 MongoDB
self.db_handler.save_result(result)
print(f"Processed video: {video_doc['object_name']}")
# print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder))
except Exception as e:
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
import traceback
traceback.print_exc() # 打印完整的错误堆栈
print("Finished processing current batch of videos. Waiting for new videos...")
time.sleep(30)
# 使用示例
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "douyin_results"
model_dir = "OpenBMB/MiniCPM-V-2_6"
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name)
system.run()
+121
View File
@@ -0,0 +1,121 @@
import io
import os
import tempfile
import time
from bson import ObjectId
from minio import Minio
from pymongo import MongoClient
import whisper
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
data = response.read()
print(f"Read {len(data)} bytes from Minio for {object_name}")
return data
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
def get_unprocessed_videos(self):
return self.collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'},
'whisper_transcription': {'$exists': False}
})
def update_transcription(self, video_id, transcription):
self.collection.update_one(
{'_id': video_id},
{'$set': {'whisper_transcription': transcription}}
)
class WhisperProcessor:
def __init__(self, model_name, model_path=None):
if model_path:
self.model = whisper.load_model(model_name, download_root=model_path)
else:
self.model = whisper.load_model(model_name)
def transcribe_audio(self, video_data):
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
temp_video.write(video_data)
temp_video_path = temp_video.name
try:
result = self.model.transcribe(temp_video_path)
return result["text"]
finally:
os.unlink(temp_video_path)
class WhisperTranscriptionSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name, whisper_model_name, whisper_model_path=None):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
self.whisper_processor = WhisperProcessor(whisper_model_name, whisper_model_path)
def process_video(self, video_doc):
start_time = time.time()
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
transcription = self.whisper_processor.transcribe_audio(video_data)
self.db_handler.update_transcription(video_doc['_id'], transcription)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed video: {video_doc['object_name']}")
print(f"Processing time: {processing_time:.2f} seconds")
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(60)
continue
for video_doc in unprocessed_videos:
self.process_video(video_doc)
print("Finished processing current batch of videos. Waiting for new videos...")
time.sleep(30)
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
collection_name = "douyin_results"
whisper_model_name = "large-v3" # 指定模型名称
whisper_model_path = "whisper" # 指定模型存放路径
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name,
whisper_model_name, whisper_model_path)
system.run()
+114
View File
@@ -0,0 +1,114 @@
import io
import os
import tempfile
import time
from bson import ObjectId
from minio import Minio
from pymongo import MongoClient
import whisper
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
data = response.read()
print(f"Read {len(data)} bytes from Minio for {object_name}")
return data
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
def get_unprocessed_videos(self):
return self.collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': 'douyin/'},
'whisper_transcription': {'$exists': False}
})
def update_transcription(self, video_id, transcription):
self.collection.update_one(
{'_id': video_id},
{'$set': {'whisper_transcription': transcription}}
)
class WhisperProcessor:
def __init__(self, model_name="large-v3"):
self.model = whisper.load_model(model_name)
def transcribe_audio(self, video_data):
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
temp_video.write(video_data)
temp_video_path = temp_video.name
try:
result = self.model.transcribe(temp_video_path)
return result["text"]
finally:
os.unlink(temp_video_path)
class WhisperTranscriptionSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
self.whisper_processor = WhisperProcessor()
def process_video(self, video_doc):
start_time = time.time()
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
transcription = self.whisper_processor.transcribe_audio(video_data)
self.db_handler.update_transcription(video_doc['_id'], transcription)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed video: {video_doc['object_name']}")
print(f"Processing time: {processing_time:.2f} seconds")
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(60)
continue
print(f"Found {len(unprocessed_videos)} videos to process.")
for video_doc in unprocessed_videos:
self.process_video(video_doc)
print("Finished processing current batch of videos. Checking for more...")
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
collection_name = "douyin_results"
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name)
system.run()
+114
View File
@@ -0,0 +1,114 @@
import io
import os
import tempfile
import time
from bson import ObjectId
from minio import Minio
from pymongo import MongoClient
import whisper
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=True
)
def get_video_data(self, bucket, object_name):
response = self.client.get_object(bucket, object_name)
data = response.read()
print(f"Read {len(data)} bytes from Minio for {object_name}")
return data
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.collection = self.db[collection_name]
def get_unprocessed_videos(self):
return self.collection.find({
'bucket_name': 'raw',
'object_name': {'$regex': 'douyin/'},
'whisper_transcription': {'$exists': False}
})
def update_transcription(self, video_id, transcription):
self.collection.update_one(
{'_id': video_id},
{'$set': {'whisper_transcription': transcription}}
)
class WhisperProcessor:
def __init__(self, model_name="large-v3"):
self.model = whisper.load_model(model_name)
def transcribe_audio(self, video_data):
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
temp_video.write(video_data)
temp_video_path = temp_video.name
try:
result = self.model.transcribe(temp_video_path)
return result["text"]
finally:
os.unlink(temp_video_path)
class WhisperTranscriptionSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
self.whisper_processor = WhisperProcessor()
def process_video(self, video_doc):
start_time = time.time()
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
transcription = self.whisper_processor.transcribe_audio(video_data)
self.db_handler.update_transcription(video_doc['_id'], transcription)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed video: {video_doc['object_name']}")
print(f"Processing time: {processing_time:.2f} seconds")
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
def run(self):
while True:
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
if not unprocessed_videos:
print("No new videos to process. Waiting for 60 seconds before checking again...")
time.sleep(60)
continue
print(f"Found {len(unprocessed_videos)} videos to process.")
for video_doc in unprocessed_videos:
self.process_video(video_doc)
print("Finished processing current batch of videos. Checking for more...")
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
collection_name = "douyin_results"
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, collection_name)
system.run()
+256
View File
@@ -0,0 +1,256 @@
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from decord import VideoReader, cpu
import json
import re
from pymongo import MongoClient
import io
from minio import Minio
import time
from bson import ObjectId
import concurrent.futures
import os
class MinioHandler:
def __init__(self, endpoint, access_key, secret_key, secure=True):
self.client = Minio(
endpoint,
access_key=access_key,
secret_key=secret_key,
secure=secure
)
def list_objects(self, bucket_name, prefix):
objects = self.client.list_objects(bucket_name, prefix=prefix, recursive=True)
return [obj for obj in objects if obj.object_name.lower().endswith(('.mp4', '.avi', '.mov', '.flv'))]
def get_video_data(self, bucket_name, object_name):
try:
response = self.client.get_object(bucket_name, object_name)
return response.read()
except Exception as e:
print(f"Error retrieving video data for {object_name}: {str(e)}")
return None
class DatabaseHandler:
def __init__(self, mongo_uri, database_name, results_collection_name):
self.client = MongoClient(mongo_uri)
self.db = self.client[database_name]
self.results_collection = self.db[results_collection_name]
def get_unprocessed_videos(self, minio_handler, bucket_name='raw', prefix='videoupload/'):
all_objects = minio_handler.list_objects(bucket_name, prefix)
processed_etags = set(self.results_collection.distinct('etag'))
unprocessed_videos = [
{
'bucket_name': bucket_name,
'object_name': obj.object_name,
'etag': obj.etag,
'size': obj.size,
'last_modified': obj.last_modified
}
for obj in all_objects if obj.etag not in processed_etags
]
return unprocessed_videos
def save_result(self, result):
existing_result = self.results_collection.find_one({'etag': result['etag']})
if existing_result:
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
return
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
result['video_id'] = str(result['video_id'])
self.results_collection.insert_one(result)
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
return super().default(o)
class VideoProcessor:
def __init__(self, model_dir):
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
self.MAX_NUM_FRAMES = 12
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
}
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "室外", "会议室", "办公"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for action in actions:
if action in answer:
info["actions"].append(action)
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
class VideoAnalysisSystem:
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name):
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
self.video_processor = VideoProcessor(model_dir)
def process_video(self, video_doc):
start_time = time.time()
try:
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
result = self.video_processor.process_video(video_data, video_doc['object_name'])
result['etag'] = video_doc['etag']
result['bucket_name'] = video_doc['bucket_name']
result['object_name'] = video_doc['object_name']
self.db_handler.save_result(result)
end_time = time.time()
processing_time = end_time - start_time
print(f"Processed video: {video_doc['object_name']}")
print(f"Processing time: {processing_time:.2f} seconds")
except Exception as e:
end_time = time.time()
processing_time = end_time - start_time
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
print(f"Processing time (including error): {processing_time:.2f} seconds")
import traceback
traceback.print_exc()
def run(self):
while True:
unprocessed_videos = self.db_handler.get_unprocessed_videos(self.minio_handler)
if not unprocessed_videos:
print("No new videos to process. Waiting for 5 seconds before checking again...")
time.sleep(1)
continue
for video_doc in unprocessed_videos:
self.process_video(video_doc)
print("Finished processing current batch of videos. Waiting for new videos...")
time.sleep(1)
if __name__ == "__main__":
minio_endpoint = "api.obscura.work"
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
db_name = "minio_mongo"
results_collection_name = "videoupload_results"
model_dir = "MiniCPM-V-2_6"
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
mongo_uri, db_name, model_dir, results_collection_name)
system.run()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+369
View File
@@ -0,0 +1,369 @@
import os
import json
import uuid
from datetime import datetime, timedelta
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from kafka import KafkaProducer, KafkaConsumer
from transformers import AutoTokenizer, AutoModel
from decord import VideoReader, cpu
from PIL import Image
from redis import Redis
import io
import re
import torch
import asyncio
from contextlib import asynccontextmanager
import threading
app = FastAPI()
cpm_app = FastAPI()
app.mount("/cpm", cpm_app)
# CORS设置
ALLOWED_ORIGINS = ['https://beta.obscura.work']
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6"
KAFKA_BROKER = "222.186.10.253:9092"
KAFKA_TOPIC = "cpm"
KAFKA_GROUP_ID = "cpm_group"
REDIS_HOST = "222.186.10.253"
REDIS_PORT = 6379
REDIS_PASSWORD = "Obscura@2024"
REDIS_DB = 5
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
MAX_FILE_AGE = timedelta(hours=1)
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
# 设置 GPU 设备
torch.cuda.set_device(0)
# 初始化模型
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = model.half().cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
class MediaAnalysisSystem:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.device = torch.device("cuda:0")
self.model = self.model.to(self.device)
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=frames, # 直接传递 frames
msgs=msgs,
tokenizer=self.tokenizer,
max_length=512,
temperature=0.7,
top_p=0.9,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
# "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
# "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = "描述这张图片,包括场景、人物数量和行为等细节。"
msgs = [
{'role': 'user', 'content': [image] + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
max_length=512,
temperature=0.7,
top_p=0.9,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "室外", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem(model, tokenizer)
async def process_file(file: UploadFile, file_type: str):
content = await file.read()
# 获取原始文件的后缀
original_extension = os.path.splitext(file.filename)[1]
# 生成新的文件名,包含 UUID 和原始后缀
filename = f"cpm_{uuid.uuid4()}{original_extension}"
file_path = os.path.join(UPLOAD_DIR, filename)
with open(file_path, "wb") as f:
f.write(content)
producer.send(KAFKA_TOPIC, json.dumps({
"filename": filename,
"type": file_type
}).encode('utf-8'))
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "queued"}))
return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename}
@cpm_app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
try:
file_type = "image" if file.content_type.startswith("image") else "video"
return await process_file(file, file_type)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@cpm_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...)):
try:
return await process_file(file, "video")
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@cpm_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...)):
try:
return await process_file(file, "image")
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
def process_task():
for message in consumer:
try:
if isinstance(message.value, dict):
task = message.value
else:
task = json.loads(message.value.decode('utf-8'))
filename = task['filename']
file_type = task['type']
file_path = os.path.join(UPLOAD_DIR, filename)
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "processing"}))
with open(file_path, 'rb') as f:
file_data = f.read()
if file_type == "video":
result = media_analysis_system.process_video(file_data, filename)
elif file_type == "image":
result = media_analysis_system.process_image(file_data, filename)
# 保存结果到 JSON 文件
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
with open(result_file_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
# 将结果存储在 Redis 中
redis_client.set(redis_key, json.dumps({
"status": "completed",
"result": result
}))
except Exception as e:
print(f"Error processing task: {str(e)}")
if 'filename' in locals() and 'file_type' in locals():
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
else:
print("Error occurred before task details were extracted")
@cpm_app.get("/result/{filename}")
async def get_result(filename: str):
for file_type in ["video", "image"]:
redis_key = f"{file_type}_result:{filename}"
result = redis_client.get(redis_key)
if result:
result_json = json.loads(result)
if result_json.get("status") == "queued":
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
elif result_json.get("status") == "processing":
return {"status": "processing", "message": "Your request is being processed."}
else:
return result_json
raise HTTPException(status_code=404, detail="Result not found")
async def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@5__:*_result:*')
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
print(f"Key changed: {key}")
if __name__ == "__main__":
# 在后台线程中启动Kafka消费者
consumer_thread = threading.Thread(target=process_task, daemon=True)
consumer_thread.start()
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7000)
+518
View File
@@ -0,0 +1,518 @@
import os
import json
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer, KafkaConsumer
from transformers import AutoTokenizer, AutoModel
from decord import VideoReader, cpu
from PIL import Image
from redis import Redis
import io
import re
import torch
import asyncio
from contextlib import asynccontextmanager
import threading
app = FastAPI()
cpm_app = FastAPI()
app.mount("/cpm", cpm_app)
# CORS设置
ALLOWED_ORIGINS = ['https://beta.obscura.work']
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6"
KAFKA_BROKER = "222.186.10.253:9092"
KAFKA_TOPIC = "cpm"
KAFKA_GROUP_ID = "cpm_group"
REDIS_HOST = "222.186.10.253"
REDIS_PORT = 6379
REDIS_PASSWORD = "Obscura@2024"
REDIS_DB = 5
REDIS_API_DB = 12
REDIS_API_USAGE_DB = 13
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
MAX_FILE_AGE = timedelta(hours=1)
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 初始化 Kafka
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
consumer = KafkaConsumer(
KAFKA_TOPIC,
bootstrap_servers=[KAFKA_BROKER],
group_id=KAFKA_GROUP_ID,
auto_offset_reset='earliest',
enable_auto_commit=True,
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
redis_api_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_DB
)
redis_api_usage_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_USAGE_DB
)
# 添加API密钥验证
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
if api_key is None:
raise HTTPException(status_code=400, detail="API密钥缺失")
return api_key
async def verify_api_key(api_key: str = Depends(get_api_key)):
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
return None
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
return None
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
return None
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
return {
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
# 更新总的token使用量
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
# 更新特定模型的token使用量
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
def calculate_tokens(file_path: str, file_type: str) -> int:
base_tokens = 0
try:
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
# 基础token:每MB文件大小消耗10个token
base_tokens = int((file_size / (1024 * 1024)) * 10)
if file_type == "image":
img = Image.open(file_path)
width, height = img.size
pixel_count = width * height
# 图片token:每100个像素额外消耗5个token
image_tokens = int((pixel_count / 10000) * 5)
base_tokens += image_tokens
elif file_type == "video":
vr = VideoReader(file_path)
fps = vr.get_avg_fps()
frame_count = len(vr)
width, height = vr[0].shape[1], vr[0].shape[0]
pixel_count = width * height * frame_count
duration = frame_count / fps # 视频时长(秒)
# 视频token:每100万像素每秒额外消耗1个token
video_tokens = int((pixel_count / 10000) * (duration / 60))
base_tokens += video_tokens
return max(1, base_tokens) # 确保至少返回1个token
except Exception as e:
print(f"计算token时出错: {str(e)}")
return 1 # 出错时返回默认值1
# 设置 GPU 设备
torch.cuda.set_device(0)
# 初始化模型
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = model.half().cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
class MediaAnalysisSystem:
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
self.device = torch.device("cuda:0")
self.model = self.model.to(self.device)
self.MAX_NUM_FRAMES = 16
def encode_video(self, video_data):
def uniform_sample(l, n):
gap = len(l) / n
return [l[int(i * gap + gap / 2)] for i in range(n)]
video_file = io.BytesIO(video_data)
vr = VideoReader(video_file, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1)
frame_idx = list(range(0, len(vr), sample_fps))
if len(frame_idx) > self.MAX_NUM_FRAMES:
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
print('num frames:', len(frames))
return frames
def process_video(self, video_data, object_name):
if not video_data:
raise ValueError(f"Empty video data for {object_name}")
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
frames = self.encode_video(video_data)
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
msgs = [
{'role': 'user', 'content': frames + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=frames, # 直接传递 frames
msgs=msgs,
tokenizer=self.tokenizer,
max_length=512,
temperature=0.7,
top_p=0.9,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"num_frames": len(frames),
# "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
# "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
}
def process_image(self, image_data, object_name):
image = Image.open(io.BytesIO(image_data))
question = "描述这张图片,包括场景、人物数量和行为等细节。"
msgs = [
{'role': 'user', 'content': [image] + [question]},
]
params = {
"use_image_id": False,
"max_slice_nums": 1
}
answer = self.model.chat(
image=None,
msgs=msgs,
tokenizer=self.tokenizer,
max_length=512,
temperature=0.7,
top_p=0.9,
**params
)
extracted_info = self.extract_info(answer)
return {
"original_answer": answer,
"extracted_info": extracted_info,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
@staticmethod
def extract_time_from_filename(object_name):
filename = os.path.basename(object_name)
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
try:
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
end_time = start_time + timedelta(seconds=10)
return start_time, end_time
except ValueError:
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
return datetime.now(), datetime.now() + timedelta(seconds=10)
@staticmethod
def extract_info(answer):
info = {
"environment": None,
"num_people": None,
"actions": [],
"interactions": [],
"objects": [],
"furniture": []
}
environments = ["办公室", "室内", "室外", "会议室"]
for env in environments:
if env in answer.lower():
info["environment"] = env
break
people_patterns = [
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
r'(男|女)(性|生|士)',
r'(成年|未成年|青少年|老年)\s*(人|群体)',
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
r'(群众|民众|大众|公众)',
r'(男女|老少|老幼|大人|小孩)'
]
for pattern in people_patterns:
match = re.search(pattern, answer)
if match:
if match.group(1).isdigit():
info["num_people"] = int(match.group(1))
elif match.group(1) in ['一个', '']:
info["num_people"] = 1
else:
num_word_to_digit = {
'': 2, '': 3, '': 4, '': 5,
'': 6, '': 7, '': 8, '': 9, '': 10
}
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
break
actions = ["", "", "摔倒", "跳舞", "转身", "", "", "倒下", "躺下", "转身", "跳跃", "", "", "", "说话"]
for action in actions:
if action in answer:
info["actions"].append(action)
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
for interaction in interactions:
if interaction in answer:
info["interactions"].append(interaction)
objects = ["水瓶", "办公用品", "文件", "电脑"]
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "", "沙发"]
for obj in objects:
if obj in answer:
info["objects"].append(obj)
for item in furniture:
if item in answer:
info["furniture"].append(item)
return info
# 初始化 MediaAnalysisSystem
media_analysis_system = MediaAnalysisSystem(model, tokenizer)
async def process_file(file: UploadFile, file_type: str, api_key: str):
content = await file.read()
original_extension = os.path.splitext(file.filename)[1]
filename = f"cpm_{uuid.uuid4()}{original_extension}"
file_path = os.path.join(UPLOAD_DIR, filename)
with open(file_path, "wb") as f:
f.write(content)
# 计算token
tokens_required = calculate_tokens(file_path, file_type)
# 检查并更新token使用量
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新token使用量
model_name = "MiniCPM-V-2_6"
await update_token_usage(api_key, tokens_required, model_name)
producer.send(KAFKA_TOPIC, json.dumps({
"filename": filename,
"type": file_type
}).encode('utf-8'))
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "queued"}))
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
return JSONResponse(content={
"message": "文件已上传并排队等待处理",
"filename": filename,
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{model_name}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
})
@cpm_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
try:
file_type = "image" if file.content_type.startswith("image") else "video"
return await process_file(file, file_type, api_key)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@cpm_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
try:
return await process_file(file, "video", api_key)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
@cpm_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
try:
return await process_file(file, "image", api_key)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)
def process_task():
for message in consumer:
try:
if isinstance(message.value, dict):
task = message.value
else:
task = json.loads(message.value.decode('utf-8'))
filename = task['filename']
file_type = task['type']
file_path = os.path.join(UPLOAD_DIR, filename)
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "processing"}))
with open(file_path, 'rb') as f:
file_data = f.read()
if file_type == "video":
result = media_analysis_system.process_video(file_data, filename)
elif file_type == "image":
result = media_analysis_system.process_image(file_data, filename)
# 保存结果到 JSON 文件
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
with open(result_file_path, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
# 将结果存储在 Redis 中
redis_client.set(redis_key, json.dumps({
"status": "completed",
"result": result
}))
except Exception as e:
print(f"Error processing task: {str(e)}")
if 'filename' in locals() and 'file_type' in locals():
redis_key = f"{file_type}_result:{filename}"
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
else:
print("Error occurred before task details were extracted")
@cpm_app.get("/result/{filename}")
async def get_result(filename: str):
for file_type in ["video", "image"]:
redis_key = f"{file_type}_result:{filename}"
result = redis_client.get(redis_key)
if result:
result_json = json.loads(result)
if result_json.get("status") == "queued":
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
elif result_json.get("status") == "processing":
return {"status": "processing", "message": "Your request is being processed."}
else:
return result_json
raise HTTPException(status_code=404, detail="Result not found")
async def listen_redis_changes():
pubsub = redis_client.pubsub()
pubsub.psubscribe('__keyspace@5__:*_result:*')
for message in pubsub.listen():
if message['type'] == 'pmessage':
key = message['channel'].decode('utf-8').split(':')[-1]
print(f"Key changed: {key}")
if __name__ == "__main__":
# 在后台线程中启动Kafka消费者
consumer_thread = threading.Thread(target=process_task, daemon=True)
consumer_thread.start()
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7000)

Some files were not shown because too many files have changed in this diff Show More