commit a3dcc7a619452dadd0377b46442d394c940eb7f5 Author: zydi Date: Sun Jan 12 06:15:15 2025 +0000 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..92b7a02 --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +api_history/OpenBMB/* +!api_history/OpenBMB/.gitkeep + +chat_history/ChatTTS/* +!chat_history/ChatTTS/.gitkeep + +api_chat/OpenBMB/* +!api_chat/OpenBMB/.gitkeep + +api_chat/GPT_SoVITS/* +!api_chat/GPT_SoVITS/.gitkeep + +api_chat/tools/* +!api_chat/tools/.gitkeep + +api_chat/runtime/* +!api_chat/runtime/.gitkeep + +api_chat/docs/* +!api_chat/docs/.gitkeep diff --git a/api/compare.py b/api/compare.py new file mode 100644 index 0000000..eb3b29c --- /dev/null +++ b/api/compare.py @@ -0,0 +1,318 @@ +import os +import cv2 +import torch +import numpy as np +from redis import Redis +from ultralytics import YOLO +import json +from kafka import KafkaConsumer +import threading +import redis +import torch +from config import * +import insightface +from insightface.app import FaceAnalysis +from insightface.utils import face_align + + +torch.cuda.set_device(1) +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +# 配置 + +KAFKA_BROKER = KAFKA_BROKER +KAFKA_TOPIC = WORKER_CONFIGS["compare"]["kafka_topic"] +KAFKA_GROUP_ID = f"compare_{KAFKA_GROUP_ID_PREFIX}" + +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = WORKER_CONFIGS["compare"]["redis_db"] # Worker使用的Redis DB +MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB + +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR + +# 初始化 Kafka +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +main_redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, 
+ password=REDIS_PASSWORD, + db=MAIN_REDIS_DB +) + + +class FaceComparator: + def __init__(self): + print("初始化 InsightFace...") + self.app = FaceAnalysis(name='buffalo_l', allowed_modules=['detection', 'recognition', 'genderage']) + self.app.prepare(ctx_id=0, det_size=(640, 640)) + print(f"InsightFace 初始化完成,使用设备: {device}") + + + def detect(self, frame): + """检测人脸并返回所有特征""" + try: + faces = self.app.get(frame) + return faces + except Exception as e: + return [] + + def format_results(self, faces): + """格式化检测结果""" + print("开始格式化检测结果...") + try: + formatted_results = [] + for i, face in enumerate(faces): + # 设置默认值,避免None导致的错误 + result = { + 'bbox': face.bbox.tolist(), + 'kps': face.kps.tolist(), + 'gender': int(face.gender) if hasattr(face, 'gender') and face.gender is not None else -1, + 'age': float(face.age) if hasattr(face, 'age') and face.age is not None else 0.0, + 'det_score': float(face.det_score), + 'embedding': face.embedding.tolist() + } + formatted_results.append(result) + + return formatted_results + except Exception as e: + print(f"格式化结果时出错: {str(e)}") + print(f"错误类型: {type(e)}") + import traceback + print(f"错误堆栈: {traceback.format_exc()}") + return [] + + def draw_results(self, frame, results): + """在图像上绘制检测结果""" + print("开始绘制检测结果...") + try: + annotated_frame = frame.copy() + + if not results: + print("没有检测结果,返回原始图像") + return annotated_frame + + for i, r in enumerate(results): + bbox = r['bbox'] + kps = r['kps'] + + cv2.rectangle(annotated_frame, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (0, 255, 0), 2) + + for x, y in kps: + cv2.circle(annotated_frame, + (int(x), int(y)), + 3, (255, 255, 0), -1) + + gender_text = 'Male' if r['gender'] == 1 else 'Female' if r['gender'] == 0 else 'Unknown' + label = f"Age: {int(r['age'])} Gender: {gender_text}" + cv2.putText(annotated_frame, + label, + (int(bbox[0]), int(bbox[1]-10)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 0), + 2) + return annotated_frame + + except Exception as e: + 
def process_video(video_path):
    """Run face analysis on about one frame per second of a video.

    Every frame is written to the result video (same file name, under
    RESULT_DIR instead of UPLOAD_DIR), annotated with the most recent
    detections.

    Args:
        video_path: Path of the uploaded video, located under UPLOAD_DIR.

    Returns:
        A list of ``{"frame": index, "detections": [...]}`` dicts, one per
        sampled frame, or ``None`` if the video could not be processed.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"无法打开视频文件: {video_path}")

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Some containers report 0 FPS; fall back to 1 so the modulo below
        # cannot divide by zero and the writer gets a usable frame rate.
        sample_interval = max(fps, 1)

        out = cv2.VideoWriter(
            video_path.replace(UPLOAD_DIR, RESULT_DIR),
            cv2.VideoWriter_fourcc(*'mp4v'),
            sample_interval,
            (width, height),
        )

        json_results = []
        last_detections = []
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_interval == 0:
                # The detector is given the raw BGR frame, exactly as
                # process_image does, and format_results() takes only the
                # face list.
                faces = detector.detect(frame)
                last_detections = detector.format_results(faces) if faces else []
                json_results.append({"frame": frame_count, "detections": last_detections})

            # Annotate the *current* frame so the output video keeps moving
            # between sampled frames instead of repeating a stale image.
            out.write(detector.draw_results(frame, last_detections))
            frame_count += 1

        return json_results
    except Exception as e:
        print(f"处理视频时出错: {str(e)}")
        return None
    finally:
        # Release native handles even when processing fails half-way.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def process_task():
    """Consume compare tasks from Kafka and publish their results to Redis.

    For each message the task hash ``task:<id>`` in the main Redis DB is set
    to ``processing``, the file is analysed, the result is stored under
    ``compare_result:<id>`` in the worker DB, and the task hash is set to
    ``completed`` or ``failed``. The loop runs until the consumer stops.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        try:
            task_id = task['task_id']
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # A malformed message must not kill the consumer thread.
            print(f"无效的任务消息,已跳过: {str(e)}")
            continue

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"
        result_key = f"compare_result:{task_id}"

        # Make sure the task key is a hash before writing the status field.
        try:
            key_type = main_redis_client.type(task_key)
            if key_type != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.RedisError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and wait for the next one

        try:
            if file_type == "image":
                label = "图像"
                print(f"开始处理图像: {filename}")
                json_results, annotated_img = process_image(file_path)
                # An empty list means "no face found", which is a valid
                # result; only None signals a processing failure.
                succeeded = json_results is not None and annotated_img is not None
                result_filename = f"compare_{filename}"
                if succeeded:
                    cv2.imwrite(os.path.join(RESULT_DIR, result_filename), annotated_img)
            else:  # video
                label = "视频"
                print(f"开始处理视频: {filename}")
                json_results = process_video(file_path)
                succeeded = bool(json_results)
                # process_video writes the annotated video to RESULT_DIR
                # under the original file name, so record that name.
                result_filename = filename

            if succeeded:
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(json_results),
                    "result_file": result_filename,
                })
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "compare",
                    "result_key": result_key,
                })
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            try:
                main_redis_client.hset(task_key, mapping={
                    "status": "failed",
                    "error": str(e),
                })
            except redis.exceptions.RedisError as redis_error:
                # Redis itself may be the reason we got here; keep consuming.
                print(f"更新任务 {task_id} 状态时出错: {str(redis_error)}")

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
# Redis connection settings.
# NOTE(security): the fallback values below, including the password, are
# committed to version control. Set REDIS_HOST / REDIS_PORT / REDIS_PASSWORD
# in the environment for real deployments and rotate the committed password.
REDIS_HOST = os.environ.get("REDIS_HOST", "150.158.144.159")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "13003"))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "Obscura@2024")
MAIN_REDIS_DB = 0
REDIS_API_DB = 2
REDIS_API_USAGE_DB = 3

# Directories for uploaded files and generated results.
UPLOAD_DIR = "/obscura/task/upload"
RESULT_DIR = "/obscura/task/result"

# Make sure both directories exist as soon as the config is imported.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Model files.
YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt"
POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt"
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt"
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
# COMPARE_MODEL_PATH = "/obscura/models/insightface/insw_r100_glint360k.onnx"

# Ollama generate endpoint.
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"

# Per-worker settings: the Kafka topic each worker consumes and the Redis DB
# it stores its results in.
WORKER_CONFIGS = {
    "yolo": {
        "kafka_topic": "yolo",
        "redis_db": 4,
    },
    "pose": {
        "kafka_topic": "pose",
        "redis_db": 5,
    },
    "qwenvl": {
        "kafka_topic": "qwenvl",
        "redis_db": 9,
    },
    "qwenvl_analyze": {
        "kafka_topic": "qwenvl_analyze",
        "redis_db": 32,
    },
    "cpm": {
        "kafka_topic": "cpm",
        "redis_db": 8,
    },
    "cpm_analyze": {
        "kafka_topic": "cpm_analyze",
        "redis_db": 31,
    },
    "fall": {
        "kafka_topic": "fall",
        "redis_db": 6,
    },
    "face": {
        "kafka_topic": "face",
        "redis_db": 7,
    },
    "mediapipe": {
        "kafka_topic": "mediapipe",
        "redis_db": 10,
    },
    "compare": {
        "kafka_topic": "compare",
        "redis_db": 30,
    },
}

# GPU device names.
CUDA_DEVICE_0 = "cuda:0"
CUDA_DEVICE_1 = "cuda:1"
class MediaAnalysisSystem:
    """Send images or sampled video frames to the MiniCPM-V model via Ollama
    with an OCR / text-analysis prompt and collect the streamed answer."""

    # Shared prompt; ``{media}`` is the only part that differs between the
    # video ("视频") and image ("图片") variants.
    _PROMPT_TEMPLATE = """您是一个高级OCR和文本分析助手。您的主要任务是:
        1) 从该{media}中高精度提取所有文本内容,包括标准文本和数字数据,
        2) 对提取的内容进行全面分析,
        3) 将信息进行逻辑性的组织和结构化,
        4) 提供从文本/数字中发现的详细见解。

        对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。
        请支持:
        - 多种语言(简体中文、繁体中文、英文)
        - 不同的文本方向和布局
        - 表格、图表和结构化数据格式
        - 特殊字符、符号和数学符号
        - 原始格式和文本位置信息"""

    def __init__(self):
        # Upper bound on the number of frames sent to the model per video.
        self.MAX_NUM_FRAMES = 16

    @staticmethod
    def _uniform_sample(items, n):
        """Pick ``n`` evenly spaced elements (bucket midpoints) from ``items``."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode raw video bytes and return about one PIL frame per second,
        capped at ``MAX_NUM_FRAMES`` evenly spaced frames."""
        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        # A clip slower than 0.5 fps would round to a step of 0, which
        # range() rejects; always advance by at least one frame.
        step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def _ask_model(self, prompt, encoded_images, log_response=False):
        """POST one generate request to Ollama and return the full answer text.

        Raises:
            requests.RequestException: on transport errors (logged, re-raised).
            Exception: when Ollama answers with a non-200 status code.
        """
        payload = {
            "model": "minicpm-v",
            "prompt": prompt,
            "images": encoded_images
        }
        try:
            # Only the connect phase is bounded: generation may legitimately
            # take long, but an unreachable server should fail fast. The
            # context manager returns the streamed connection to the pool.
            with requests.post(OLLAMA_URL, json=payload, stream=True,
                               timeout=(10, None)) as response:
                if log_response:
                    print(f"Ollama API 响应状态码: {response.status_code}")
                    print(f"Ollama API 响应头: {response.headers}")
                if response.status_code != 200:
                    raise Exception(f"Ollama API 错误: {response.status_code}")
                return self.process_stream_response(response)
        except requests.RequestException as e:
            print(f"请求 Ollama API 时出错: {str(e)}")
            raise

    def process_video(self, video_data, object_name):
        """Analyse a video given as raw bytes.

        Returns:
            ``{"original_answer": str, "num_frames": int}``.

        Raises:
            ValueError: if ``video_data`` is empty.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        encoded_frames = [self.image_to_base64(frame) for frame in frames]
        answer = self._ask_model(
            self._PROMPT_TEMPLATE.format(media="视频"),
            encoded_frames,
            log_response=True,
        )
        return {
            "original_answer": answer,
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Analyse an image given as raw bytes.

        Returns:
            ``{"original_answer": str, "timestamp": "YYYY-mm-dd HH:MM:SS"}``.
        """
        image = Image.open(io.BytesIO(image_data))
        answer = self._ask_model(
            self._PROMPT_TEMPLATE.format(media="图片"),
            [self.image_to_base64(image)],
        )
        return {
            "original_answer": answer,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    def process_stream_response(self, response):
        """Concatenate the ``response`` fields of a line-delimited JSON stream.

        Reading stops at the first chunk whose ``done`` flag is true. Lines
        that are not valid JSON are logged and skipped.

        Raises:
            RuntimeError: if a chunk carries an ``error`` field, so a failed
                generation is not reported as an empty answer.
        """
        pieces = []
        for line in response.iter_lines():
            if not line:
                continue
            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                print(f"无法解析 JSON 行: {line}")
                continue
            if 'error' in chunk:
                raise RuntimeError(f"Ollama API 错误: {chunk['error']}")
            if 'response' in chunk:
                pieces.append(chunk['response'])
            if chunk.get('done', False):
                break
        return ''.join(pieces)

    @staticmethod
    def image_to_base64(image):
        """Return ``image`` encoded as PNG and then as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a ``YYYYmmdd_HHMMSS`` prefix from a file name.

        Returns:
            ``(start_time, end_time)`` with ``end_time`` ten seconds after
            ``start_time``. When the name does not follow the pattern the
            current time is used as the start.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        if len(parts) >= 2:
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            try:
                start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
                return start_time, start_time + timedelta(seconds=10)
            except ValueError:
                pass
        # Names without an underscore used to raise IndexError here; they now
        # take the same fallback as unparsable timestamps.
        print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
        now = datetime.now()
        return now, now + timedelta(seconds=10)
def listen_redis_changes():
    """Print every update made to a ``cpm_analyze_result:*`` hash.

    Blocks forever on a Redis pub/sub subscription; intended to run in its
    own thread.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are scoped to one database, so the DB index must be
    # the one redis_client writes results to (REDIS_DB), not a fixed number.
    # NOTE(review): events are only published when the server has
    # `notify-keyspace-events` enabled - confirm the deployment sets it.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:cpm_analyze_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        key = message['channel'].decode('utf-8').split(':')[-1]
        operation = message['data'].decode('utf-8')
        if operation != 'hset':
            continue

        value = redis_client.hgetall(f"cpm_analyze_result:{key}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {key}: {result}")
class MediaAnalysisSystem:
    """Send surveillance images or sampled video frames to the MiniCPM-V
    model via Ollama and extract simple keyword-based scene facts from the
    answer."""

    _VIDEO_PROMPT = """请对这段监控视频进行详细分析,包括以下方面:
        1. 场景中人数的精确统计
        2. 每个人的个人行为分析
        3. 面部表情识别和情绪状态评估
        4. 整体场景和环境的详细描述
        5. 人与人之间的互动情况
        6. 时间和环境条件(如果可见)
        7. 任何可疑或异常活动
        8. 人员的具体特征(估计年龄范围、性别、着装)
        9. 人员的移动模式和方向
        10. 携带的物品或物体
        11. 群体动态和聚集情况
        12. 视频中的时间戳分析(如果有)

        请用清晰、有条理的格式描述,并突出重要发现。"""

    _IMAGE_PROMPT = """请对这张监控图像进行详细分析,包括以下方面:
        1. 场景中人数的精确统计
        2. 每个人的个人行为分析
        3. 面部表情识别和情绪状态评估
        4. 整体场景和环境的详细描述
        5. 人与人之间的互动情况
        6. 时间和环境条件(如果可见)
        7. 任何可疑或异常活动
        8. 人员的具体特征(估计年龄范围、性别、着装)
        9. 人员的位置和姿态
        10. 携带的物品或物体
        11. 群体动态和聚集情况
        12. 图像中的时间戳信息(如果有)

        请用清晰、有条理的格式描述,并突出重要发现。"""

    def __init__(self):
        # Upper bound on the number of frames sent to the model per video.
        self.MAX_NUM_FRAMES = 16

    @staticmethod
    def _uniform_sample(items, n):
        """Pick ``n`` evenly spaced elements (bucket midpoints) from ``items``."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode raw video bytes and return about one PIL frame per second,
        capped at ``MAX_NUM_FRAMES`` evenly spaced frames."""
        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        # A clip slower than 0.5 fps would round to a step of 0, which
        # range() rejects; always advance by at least one frame.
        step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def _ask_model(self, prompt, encoded_images, log_response=False):
        """POST one generate request to Ollama and return the full answer text.

        Raises:
            requests.RequestException: on transport errors (logged, re-raised).
            Exception: when Ollama answers with a non-200 status code.
        """
        payload = {
            "model": "minicpm-v",
            "prompt": prompt,
            "images": encoded_images
        }
        try:
            # Only the connect phase is bounded: generation may legitimately
            # take long, but an unreachable server should fail fast. The
            # context manager returns the streamed connection to the pool.
            with requests.post(OLLAMA_URL, json=payload, stream=True,
                               timeout=(10, None)) as response:
                if log_response:
                    print(f"Ollama API 响应状态码: {response.status_code}")
                    print(f"Ollama API 响应头: {response.headers}")
                if response.status_code != 200:
                    raise Exception(f"Ollama API 错误: {response.status_code}")
                return self.process_stream_response(response)
        except requests.RequestException as e:
            print(f"请求 Ollama API 时出错: {str(e)}")
            raise

    def process_video(self, video_data, object_name):
        """Analyse a video given as raw bytes.

        Returns:
            ``{"original_answer": str, "extracted_info": dict,
            "num_frames": int}``.

        Raises:
            ValueError: if ``video_data`` is empty.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        encoded_frames = [self.image_to_base64(frame) for frame in frames]
        answer = self._ask_model(self._VIDEO_PROMPT, encoded_frames, log_response=True)
        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Analyse an image given as raw bytes.

        Returns:
            ``{"original_answer": str, "extracted_info": dict,
            "timestamp": "YYYY-mm-dd HH:MM:SS"}``.
        """
        image = Image.open(io.BytesIO(image_data))
        answer = self._ask_model(self._IMAGE_PROMPT, [self.image_to_base64(image)])
        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    def process_stream_response(self, response):
        """Concatenate the ``response`` fields of a line-delimited JSON stream.

        Reading stops at the first chunk whose ``done`` flag is true. Lines
        that are not valid JSON are logged and skipped.

        Raises:
            RuntimeError: if a chunk carries an ``error`` field, so a failed
                generation is not reported as an empty answer.
        """
        pieces = []
        for line in response.iter_lines():
            if not line:
                continue
            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                print(f"无法解析 JSON 行: {line}")
                continue
            if 'error' in chunk:
                raise RuntimeError(f"Ollama API 错误: {chunk['error']}")
            if 'response' in chunk:
                pieces.append(chunk['response'])
            if chunk.get('done', False):
                break
        return ''.join(pieces)

    @staticmethod
    def image_to_base64(image):
        """Return ``image`` encoded as PNG and then as a base64 string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a ``YYYYmmdd_HHMMSS`` prefix from a file name.

        Returns:
            ``(start_time, end_time)`` with ``end_time`` ten seconds after
            ``start_time``. When the name does not follow the pattern the
            current time is used as the start.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        if len(parts) >= 2:
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            try:
                start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
                return start_time, start_time + timedelta(seconds=10)
            except ValueError:
                pass
        # Names without an underscore used to raise IndexError here; they now
        # take the same fallback as unparsable timestamps.
        print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
        now = datetime.now()
        return now, now + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull coarse scene facts out of the model's free-text answer.

        Returns:
            A dict with ``environment`` (first matching keyword or None),
            ``num_people`` (int or None), and the keyword lists ``actions``,
            ``interactions``, ``objects`` and ``furniture`` in keyword order
            without duplicates. ``num_people`` is 0 when people are mentioned
            but no usable number was found, and None when nothing matched.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        # NOTE(review): the third keyword was mis-encoded ("��外") and could
        # never match; restored as "室外" - confirm the intended word.
        # "户外" is appended last so existing precedence is unchanged.
        environments = ["办公室", "室内", "室外", "会议室", "户外"]
        lowered = answer.lower()
        info["environment"] = next((env for env in environments if env in lowered), None)

        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        num_word_to_digit = {
            '一个': 1, '一': 1,
            '二': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
        }
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if not match:
                continue
            token = match.group(1)
            if token.isdigit():
                info["num_people"] = int(token)
            else:
                # Vague mentions ("几个", "男", "员工", ...) yield 0.
                info["num_people"] = num_word_to_digit.get(token, 0)
            break

        def present(keywords):
            # dict.fromkeys drops repeated keywords (e.g. "转身" is listed
            # twice) while keeping their order.
            return [word for word in dict.fromkeys(keywords) if word in answer]

        info["actions"] = present(["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"])
        info["interactions"] = present(["互动", "交流", "身体语言", "交谈", "讨论", "开会"])
        info["objects"] = present(["水瓶", "办公用品", "文件", "电脑"])
        info["furniture"] = present(["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"])

        return info
def listen_redis_changes():
    """Print every update made to a ``cpm_result:*`` hash.

    Blocks forever on a Redis pub/sub subscription; intended to run in its
    own thread.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are scoped to one database, so the DB index must be
    # the one redis_client writes results to (REDIS_DB), not a fixed number.
    # NOTE(review): events are only published when the server has
    # `notify-keyspace-events` enabled - confirm the deployment sets it.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:cpm_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        key = message['channel'].decode('utf-8').split(':')[-1]
        operation = message['data'].decode('utf-8')
        if operation != 'hset':
            continue

        value = redis_client.hgetall(f"cpm_result:{key}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {key}: {result}")
class faceDetector:
    """Wrapper around a YOLO face model that formats and draws detections."""

    def __init__(self, model_path, device='cuda:1'):
        """Load the model from ``model_path`` onto ``device``.

        ``device`` defaults to the GPU this worker has always used.
        """
        self.device = device
        self.model = YOLO(model_path).to(device)

    def detect(self, frame):
        """Run the model on ``frame`` and return its raw results."""
        return self.model(frame, device=self.device)

    def format_results(self, results, original_shape):
        """Convert raw model results into JSON-serialisable dicts.

        Coordinates are rescaled from the size the model saw
        (``r.orig_shape``) to ``original_shape``.

        Args:
            results: Iterable of result objects exposing ``boxes``,
                ``keypoints`` and ``orig_shape``.
            original_shape: ``(height, width, ...)`` of the source image.

        Returns:
            A list of ``{"bbox": [x1, y1, x2, y2], "confidence": float,
            "keypoints": [[x, y], ...]}`` dicts holding only built-in types.
        """
        orig_h, orig_w = original_shape[:2]
        formatted_results = []
        for r in results:
            boxes = r.boxes
            keypoints = r.keypoints
            # The scale depends only on the result, not on the box.
            model_h, model_w = r.orig_shape
            scale_x, scale_y = orig_w / model_w, orig_h / model_h
            scale = np.array([scale_x, scale_y])

            # Index access is kept because only len() and [] are known to be
            # supported by the boxes/keypoints containers.
            for i in range(len(boxes)):
                box = boxes[i]
                bbox = box.xyxy[0].cpu().numpy()
                # Plain floats: NumPy scalars (float32 in particular) make
                # json.dumps raise when the result is stored.
                bbox_scaled = [
                    float(bbox[0]) * scale_x, float(bbox[1]) * scale_y,
                    float(bbox[2]) * scale_x, float(bbox[3]) * scale_y
                ]

                # A result without keypoints yields an empty list instead of
                # failing on indexing None.
                if keypoints is None:
                    kpts_scaled = []
                else:
                    kpts_scaled = (keypoints[i].xy[0].cpu().numpy() * scale).tolist()

                formatted_results.append({
                    "bbox": bbox_scaled,
                    "confidence": float(box.conf.item()),
                    "keypoints": kpts_scaled
                })
        return formatted_results

    def draw_results(self, frame, results):
        """Return a copy of ``frame`` with boxes and keypoints drawn on it."""
        annotated_frame = frame.copy()
        for r in results:
            bbox = r["bbox"]

            # Bounding box in green.
            cv2.rectangle(annotated_frame,
                          (int(bbox[0]), int(bbox[1])),
                          (int(bbox[2]), int(bbox[3])),
                          (0, 255, 0), 2)

            # Keypoints as filled dots.
            for kp in r["keypoints"]:
                cv2.circle(annotated_frame,
                           (int(kp[0]), int(kp[1])),
                           5, (255, 0, 0), -1)
        return annotated_frame
def process_task():
    """Consume face-detection tasks from Kafka and publish their results.

    For every message the hash ``task:<task_id>`` in the main Redis DB is moved
    through ``processing`` -> ``completed`` / ``failed``; the detection payload
    is stored under ``face_result:<task_id>`` in the worker Redis DB.
    Runs for as long as the Kafka consumer yields messages.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        task_id = task['task_id']
        filename = task['filename']
        file_type = task['file_type']

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"
        result_key = f"face_result:{task_id}"

        # Drop any non-hash value stored under the task key so the HSET below
        # cannot fail with WRONGTYPE, then mark the task as in progress.
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and wait for the next message

        try:
            if file_type == "image":
                label = "图像"
                print(f"开始处理图像: {filename}")
                json_results, annotated_img = process_image(file_path)
                # process_image signals failure with None; an empty list just
                # means "no faces found" and is a valid result.
                succeeded = json_results is not None and annotated_img is not None
                result_filename = f"face_{filename}"
                if succeeded:
                    result_path = os.path.join(RESULT_DIR, result_filename)
                    if not cv2.imwrite(result_path, annotated_img):
                        raise IOError(f"could not write {result_path}")
            else:  # video
                label = "视频"
                print(f"开始处理视频: {filename}")
                json_results = process_video(file_path)
                # An empty list means no frame could be read, i.e. a failure.
                succeeded = bool(json_results)
                # process_video writes to video_path.replace(UPLOAD_DIR,
                # RESULT_DIR), i.e. it keeps the upload's file name, so that is
                # the name that has to be recorded.
                result_filename = filename

            if succeeded:
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(json_results),
                    "result_file": result_filename
                })
                # One HSET so readers never observe "completed" without the
                # accompanying result fields.
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "face",
                    "result_key": result_key
                })
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            # Worker boundary: record the failure and keep consuming.
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            main_redis_client.hset(task_key, mapping={
                "status": "failed",
                "error": str(e)
            })

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
class fallDetector:
    """Thin wrapper around a YOLO model used for fall detection."""

    def __init__(self, model_path):
        """Load the YOLO weights found at *model_path*."""
        self.model = YOLO(model_path)

    def detect(self, frame):
        """Run inference on *frame* and return the raw model results."""
        return self.model(frame)

    def format_results(self, results):
        """Convert raw results into JSON-serialisable detection dicts.

        Returns:
            A list of ``{"bbox", "confidence", "class_id", "class"}`` dicts
            collected from every result. If at least one result was seen and
            nothing was detected at all, ``[{"message": "No objects
            detected"}]`` is returned instead. An empty *results* gives ``[]``.
        """
        formatted_results = []
        saw_empty_result = False
        for r in results:
            boxes = getattr(r, 'boxes', None)
            if boxes is None or len(boxes) == 0:
                # Keep going: other results in the batch may hold detections.
                saw_empty_result = True
                continue

            names = getattr(r, 'names', {})
            for i in range(len(boxes)):
                box = boxes[i]
                if not (hasattr(box, 'cls') and hasattr(box, 'conf') and hasattr(box, 'xyxy')):
                    print(f"警告: 第 {i} 个框缺少必要的属性")
                    continue

                try:
                    class_id = int(box.cls.item())
                    formatted_results.append({
                        "bbox": box.xyxy.tolist()[0],
                        "confidence": box.conf.item(),
                        "class_id": class_id,
                        "class": names.get(class_id, f"Unknown-{class_id}")
                    })
                except Exception as e:
                    # A malformed box must not discard the remaining ones.
                    print(f"处理第 {i} 个框时出错: {str(e)}")

        if not formatted_results and saw_empty_result:
            print("没有检测到任何对象")
            return [{"message": "No objects detected"}]
        return formatted_results

    def draw_results(self, frame, results):
        """Return the annotated image for *results*.

        The plot of the last result is returned; when *results* is empty the
        unmodified *frame* is returned so callers always get something they
        can write out.
        """
        annotated_frame = frame
        for r in results:
            annotated_frame = r.plot()
        return annotated_frame
def process_task():
    """Consume fall-detection tasks from Kafka and publish their results.

    For every message the hash ``task:<task_id>`` in the main Redis DB is moved
    through ``processing`` -> ``completed`` / ``failed``; the detection payload
    is stored under ``fall_result:<task_id>`` in the worker Redis DB.
    Runs for as long as the Kafka consumer yields messages.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        task_id = task['task_id']
        filename = task['filename']
        file_type = task['file_type']

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"
        result_key = f"fall_result:{task_id}"

        # Drop any non-hash value stored under the task key so the HSET below
        # cannot fail with WRONGTYPE, then mark the task as in progress.
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and wait for the next message

        try:
            with open(file_path, 'rb') as f:
                payload = f.read()

            if file_type == "image":
                label = "图像"
                print(f"开始处理图像: {filename}")
                json_results, annotated_filename = process_image(payload, filename)
                succeeded = bool(json_results) and annotated_filename is not None
            else:  # video
                label = "视频"
                print(f"开始处理视频: {filename}")
                json_results, annotated_filename = process_video(payload, filename)
                succeeded = bool(json_results) and bool(annotated_filename)

            if succeeded:
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(json_results),
                    "result_file": annotated_filename
                })
                # One HSET so readers never observe "completed" without the
                # accompanying result fields.
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "fall",
                    "result_key": result_key
                })
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            # Worker boundary: record the failure and keep consuming. The dict
            # must go through ``mapping=``; passed positionally it is taken as
            # a field name and the call itself raises.
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            main_redis_client.hset(task_key, mapping={
                "status": "failed",
                "error": str(e)
            })

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
threading.Thread(target=listen_redis_changes, daemon=True) + redis_thread.start() + print("Redis监听线程已启动") + + print("主程序进入等待状态...") + # 保持主线程运行 + task_thread.join() + redis_thread.join() \ No newline at end of file diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..078fd33 --- /dev/null +++ b/api/main.py @@ -0,0 +1,426 @@ +# main.py +from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer +from redis import Redis +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +import string +from decord import VideoReader +from PIL import Image +from fastapi.responses import FileResponse +import logging +from config import * + +app = FastAPI() +v1_app = FastAPI() +app.mount("/v1", v1_app) + + +# CORS设置 +# ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +KAFKA_BROKER = KAFKA_BROKER +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = MAIN_REDIS_DB +REDIS_API_DB = REDIS_API_DB +REDIS_API_USAGE_DB = REDIS_API_USAGE_DB +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 定义支持的任务类型 +KAFKA_TOPICS = { + 'pose': 'pose', + 'mediapipe': 'mediapipe', + 'qwenvl': 'qwenvl', + 'yolo': 'yolo', + 'fall': 'fall', + 'face': 'face', + 'cpm': 'cpm', + 'compare': 'compare' +} + +TASK_TYPES = list(KAFKA_TOPICS.keys()) + + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + 
port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) +redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db']) +redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db']) +redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db']) +redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db']) +redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db']) +redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db']) +redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db']) +redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db']) +@v1_app.get('/favicon.ico', include_in_schema=False) +async def favicon(): + file_name = "favicon.ico" + file_path = os.path.join(app.root_path, "static", file_name) + if os.path.isfile(file_path): + return FileResponse(file_path) + else: + return {"message": "Favicon not found"}, 404 + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + 
status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + logging.info(f"验证API密钥: {api_key}") + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + logging.warning(f"API密钥不存在: {api_key}") + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + logging.warning(f"API密钥已停用: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + logging.warning(f"API密钥已过期: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + logging.info(f"API密钥验证成功: {api_key}") + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 0.1) + + if file_type == "image": + 
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
    """Store an upload, charge tokens for it and queue it for a worker.

    Args:
        file: the uploaded file.
        task_type: one of the keys of ``KAFKA_TOPICS``.
        api_key_info: dict returned by ``verify_api_key``; only ``api_key``
            is read here.

    Returns:
        ``JSONResponse`` with the task id and the token accounting.

    Raises:
        HTTPException: 400 for an unsupported task type, 403 when the key's
            remaining tokens do not cover this upload.
    """
    if task_type not in KAFKA_TOPICS:
        raise HTTPException(status_code=400, detail="不支持的任务类型")

    content = await file.read()
    # ``filename`` can be missing on a multipart part; treat that as "no extension".
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"

    file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(file_path, "wb") as f:
        f.write(content)

    # Anything that is not a known image extension is handled as video.
    file_type = "image" if file_extension in ('.jpg', '.jpeg', '.png') else "video"
    tokens_required = calculate_tokens(file_path, file_type)

    # Check the quota before charging.
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        # The rejected upload will never be processed; do not leave it on disk.
        try:
            os.remove(file_path)
        except OSError as e:
            logging.warning(f"无法删除文件 {file_path}: {e}")
        raise HTTPException(status_code=403, detail="Token 余额不足")

    await update_token_usage(api_key, tokens_required, task_type)

    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "filename": new_filename,
        "file_type": file_type,
        "task_type": task_type,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Stored as a hash: the result endpoints read this key with HGETALL and
    # the workers update it with HSET, so a plain string value would make
    # those reads fail until a worker replaced it.
    redis_client.hset(f"task:{task_id}", mapping=task_data)
    logging.info(f"任务信息已存储到Redis: {task_id}")

    kafka_topic = KAFKA_TOPICS[task_type]
    producer.send(kafka_topic, task_data)
    logging.info(f"任务已发送到Kafka主题: {kafka_topic}")

    # Re-read the usage counters so the response reflects this charge.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))

    response_data = {
        "message": "文件已上传并排队等待处理",
        "task_id": task_id,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{task_type}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
    logging.info(f"上传文件完成: {task_id}")
    return JSONResponse(content=response_data)
dict = Depends(verify_api_key)): + logging.info(f"收到 /face 端点的请求") + return await upload_file(file, task_type="face", api_key_info=api_key_info) + +@v1_app.post("/mediapipe") +async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info) + +@v1_app.post("/compare") +async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="compare", api_key_info=api_key_info) + + +@v1_app.get("/result/{task_id}") +async def get_result(task_id: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + # 从 REDIS_DB (15) 获取任务状态 + task_info = redis_client.hgetall(f"task:{task_id}") + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()} + + if task_info['status'] != 'completed': + return {"status": task_info['status'], "message": "Task is not completed yet"} + + result_type = task_info['result_type'] + result_key = task_info['result_key'] + + # 根据任务类型选择相应的 Redis 客户端 + redis_client_map = { + 'pose': redis_pose_client, + 'cpm': redis_cpm_client, + 'yolo': redis_yolo_client, + 'face': redis_face_client, + 'fall': redis_fall_client, + 'mediapipe': redis_mediapipe_client, + 'qwenvl': redis_qwenvl_client, + 'compare': redis_compare_client + } + + result_redis = redis_client_map.get(result_type) + if not result_redis: + raise HTTPException(status_code=400, detail="Unsupported result type") + + result = result_redis.hgetall(result_key) + if not result: + raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found") + + result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} + + # 将 result 字段解析为 JSON(如果存在) + if 
@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
    """Return the annotated result file of a completed task.

    Raises:
        HTTPException: 404 when the task, its result record or the file is
            missing; 400 when the task is not completed or its type has no
            annotated output.
    """
    import mimetypes
    from redis.exceptions import ResponseError

    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    # Task state lives in the main Redis DB.
    try:
        task_info = redis_client.hgetall(f"task:{task_id}")
    except ResponseError:
        # The key exists but is not a hash yet (a freshly queued task stored
        # as a plain value): it cannot be completed.
        raise HTTPException(status_code=400, detail="Task is not completed yet")
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}

    if task_info.get('status') != 'completed':
        raise HTTPException(status_code=400, detail="Task is not completed yet")

    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')

    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")

    if result_type in ['cpm', 'qwenvl']:
        raise HTTPException(status_code=400, detail="Annotated image not available for this task type")

    # Each worker type keeps its results in its own Redis DB.
    redis_client_map = {
        'pose': redis_pose_client,
        'yolo': redis_yolo_client,
        'face': redis_face_client,
        'fall': redis_fall_client,
        'mediapipe': redis_mediapipe_client,
        'compare': redis_compare_client
    }

    result_redis = redis_client_map.get(result_type)
    if not result_redis:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")

    result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    result_file = result.get('result_file')
    if not result_file:
        raise HTTPException(status_code=404, detail="Result file not found")

    file_path = os.path.join(RESULT_DIR, result_file)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Result image file not found")

    # Results keep the upload's extension (jpg, png, mp4, ...), so derive the
    # content type from the file name instead of always claiming PNG.
    media_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
    return FileResponse(file_path, media_type=media_type)
class mediapipeEmbedder:
    """Face-landmark extractor built on MediaPipe's FaceLandmarker.

    Produces, for the first detected face, the raw landmarks and a small
    hand-crafted feature vector derived from them.
    """

    def __init__(self, model_path):
        """Create a single-face FaceLandmarker from the model at *model_path*."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Return the landmarks of the first face in *image* or ``None``.

        *image* is a BGR array as produced by OpenCV. It is converted to RGB
        here because the MediaPipe image is declared as ``SRGB``.

        Returns:
            ``(num_landmarks, 3)`` array of ``(x, y, z)`` values, or ``None``
            when no face was found.
        """
        rgb = np.ascontiguousarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
        return None

    def process_image(self, image_data):
        """Decode *image_data* (encoded image bytes) and extract face features.

        Returns:
            ``(result, image)`` where *result* is
            ``{"embedding": [...], "landmarks": [...]}`` or ``None`` when no
            face was found, and *image* is the decoded BGR image with the
            landmarks drawn on it. ``(None, None)`` when the bytes cannot be
            decoded.
        """
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # Not a decodable image; nothing to analyse or draw on.
            return None, None

        landmarks = self.get_mediapipe_landmarks(img)
        if landmarks is None:
            return None, img

        embedding = self.calculate_detailed_embedding(landmarks)

        # x and y are multiplied by the image width/height to get pixel positions.
        height, width = img.shape[:2]
        for lm in landmarks:
            cv2.circle(img, (int(lm[0] * width), int(lm[1] * height)), 2, (0, 255, 0), -1)

        return {
            "embedding": embedding,
            "landmarks": landmarks.tolist()
        }, img

    def calculate_detailed_embedding(self, landmarks):
        """Build a 21-value feature vector from an ``(N, 3)`` landmark array.

        Layout: per-axis mean, std, median, min and max (5 x 3 values)
        followed by eye distance, mouth width, nose-to-mouth distance, and the
        face extent along x, y and z. Landmark indices up to 386 are read, so
        *landmarks* must have at least 387 rows.

        Returns:
            The features as a plain list of floats.
        """
        # Per-axis statistics over all landmarks.
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Distances between a few fixed landmark indices.
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]

        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Extent of the landmark cloud along each axis (same as max - min).
        face_width, face_height, face_depth = max_vals - min_vals

        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
        ])

        return embedding.tolist()
def process_task():
    """Consume MediaPipe tasks from Kafka and publish their results.

    For every message the hash ``task:<task_id>`` in the main Redis DB is moved
    through ``processing`` -> ``completed`` / ``failed``; the landmark payload
    is stored under ``mediapipe_result:<task_id>`` in the worker Redis DB.
    Runs for as long as the Kafka consumer yields messages.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        task_id = task['task_id']
        filename = task['filename']
        file_type = task['file_type']

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"
        result_key = f"mediapipe_result:{task_id}"

        # Drop any non-hash value stored under the task key so the HSET below
        # cannot fail with WRONGTYPE, then mark the task as in progress.
        try:
            if main_redis_client.type(task_key) != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and wait for the next message

        try:
            with open(file_path, 'rb') as f:
                payload = f.read()

            if file_type == "image":
                label = "图像"
                print(f"开始处理图像: {filename}")
                json_results, annotated_filename = process_image(payload, filename)
                succeeded = bool(json_results) and annotated_filename is not None
            else:  # video
                label = "视频"
                print(f"开始处理视频: {filename}")
                json_results, annotated_filename = process_video(payload, filename)
                succeeded = bool(json_results) and bool(annotated_filename)

            if succeeded:
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(json_results),
                    "result_file": annotated_filename
                })
                # One HSET so readers never observe "completed" without the
                # accompanying result fields.
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "mediapipe",
                    "result_key": result_key
                })
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            # Worker boundary: record the failure and keep consuming. The dict
            # must go through ``mapping=``; passed positionally it is taken as
            # a field name and the call itself raises.
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            main_redis_client.hset(task_key, mapping={
                "status": "failed",
                "error": str(e)
            })

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
允许所有HTTP头 +) + +# 创建异步HTTP客户端 +async_client = httpx.AsyncClient() + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) +async def proxy_to_ollama(request: Request, path: str): + if request.method == "OPTIONS": + # 处理预检请求 + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "*", + } + return StreamingResponse(content=iter([]), headers=headers) + + target_url = f"{OLLAMA_URL}/{path}" + + # 获取请求体 + body = await request.body() + + # 获取请求头 + headers = dict(request.headers) + headers.pop("host", None) + + try: + # 将请求转换为Python请求 + python_request = { + "method": request.method, + "url": target_url, + "headers": headers, + "data": body + } + + # 使用Python请求发送到Ollama API + async with async_client.stream(**python_request) as response: + # 返回Ollama API的流式响应,并添加CORS头 + response_headers = dict(response.headers) + response_headers["Access-Control-Allow-Origin"] = "*" + return StreamingResponse( + response.aiter_raw(), + status_code=response.status_code, + headers=response_headers + ) + except httpx.RequestError as e: + return {"error": f"请求Ollama API时发生错误: {str(e)}"}, 500 + except httpx.StreamClosed: + # 处理流关闭异常 + print("流已关闭,客户端可能已断开连接") + return {"error": "流已关闭,客户端可能已断开连接"}, 499 + +@app.on_event("shutdown") +async def shutdown_event(): + await async_client.aclose() + +if __name__ == "__main__": + import uvicorn + import requests + import json + + # 测试Ollama API + test_url = "http://localhost:11434/api/generate" + test_data = { + "model": "llama3.1", + "prompt": "Why is the sky blue?", + "stream": False + } + + try: + response = requests.post(test_url, json=test_data) + if response.status_code == 200: + print("Ollama API 测试成功:") + print(json.dumps(response.json(), indent=2)) + else: + print(f"Ollama API 测试失败,状态码: {response.status_code}") + print(response.text) + except requests.RequestException as e: + print(f"Ollama API 测试出错: {str(e)}") + + # 
启动FastAPI应用 + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/api/pose.py b/api/pose.py new file mode 100644 index 0000000..42f71d9 --- /dev/null +++ b/api/pose.py @@ -0,0 +1,292 @@ +import os +import cv2 +import torch +import numpy as np +from redis import Redis +from ultralytics import YOLO +import json +from kafka import KafkaConsumer +import threading +import redis +import torch +from config import * +torch.cuda.set_device(CUDA_DEVICE_1) +# 配置 +MODEL_PATH = POSE_MODEL_PATH +KAFKA_BROKER = KAFKA_BROKER +KAFKA_TOPIC = WORKER_CONFIGS["pose"]["kafka_topic"] +KAFKA_GROUP_ID = f"pose_{KAFKA_GROUP_ID_PREFIX}" + +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = WORKER_CONFIGS["pose"]["redis_db"] # Worker使用的Redis DB +MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB + +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +main_redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=MAIN_REDIS_DB +) + +class PoseDetector: + def __init__(self, model_path): + self.model = YOLO(model_path).to(CUDA_DEVICE_1) + + def detect(self, frame): + results = self.model(frame, device=CUDA_DEVICE_1) + return results + + def format_results(self, results, original_shape): + formatted_results = [] + for r in results: + boxes = r.boxes + keypoints = r.keypoints + for i in range(len(boxes)): + box = boxes[i] + kpts = keypoints[i] + + # 调整边界框坐标以适应原始图像大小 + orig_h, orig_w = original_shape[:2] + model_h, model_w = r.orig_shape + scale_x, scale_y = orig_w / model_w, 
orig_h / model_h + + bbox = box.xyxy[0].cpu().numpy() + bbox_scaled = [ + bbox[0] * scale_x, bbox[1] * scale_y, + bbox[2] * scale_x, bbox[3] * scale_y + ] + + # 调整关键点坐标以适应原始图像大小 + kpts_scaled = kpts.xy[0].cpu().numpy() * np.array([scale_x, scale_y]) + + formatted_results.append({ + "bbox": bbox_scaled, + "confidence": box.conf.item(), + "keypoints": kpts_scaled.tolist() + }) + return formatted_results + + def draw_results(self, frame, results): + annotated_frame = frame.copy() + for r in results: + bbox = r["bbox"] + keypoints = r["keypoints"] + + # 绘制边界框 + cv2.rectangle(annotated_frame, + (int(bbox[0]), int(bbox[1])), + (int(bbox[2]), int(bbox[3])), + (0, 255, 0), 2) + + # 绘制关键点 + for kp in keypoints: + cv2.circle(annotated_frame, + (int(kp[0]), int(kp[1])), + 5, (255, 0, 0), -1) + return annotated_frame + +detector = PoseDetector(MODEL_PATH) + +def process_image(image_path): + try: + original_img = cv2.imread(image_path) + original_shape = original_img.shape + + img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (640, 640)) + img = img.transpose((2, 0, 1)) + img = np.ascontiguousarray(img) + img = torch.from_numpy(img).float() + img /= 255.0 + img = img.unsqueeze(0) + + results = detector.detect(img) + + json_results = detector.format_results(results, original_shape) + + annotated_img = detector.draw_results(original_img, json_results) + + return json_results, annotated_img + except Exception as e: + print(f"处理图像时出错: {str(e)}") + return None, None + +def process_video(video_path): + try: + cap = cv2.VideoCapture(video_path) + frame_count = 0 + json_results = [] + + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_shape = (height, width) + + out = cv2.VideoWriter(video_path.replace(UPLOAD_DIR, RESULT_DIR), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + if 
frame_count % fps == 0: + preprocessed_frame = preprocess_frame(frame) + + results = detector.detect(preprocessed_frame) + frame_json_results = detector.format_results(results, original_shape) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + annotated_frame = detector.draw_results(frame, frame_json_results) + out.write(annotated_frame) + + frame_count += 1 + + cap.release() + out.release() + + return json_results + except Exception as e: + print(f"处理视频时出错: {str(e)}") + return None + +def preprocess_frame(frame): + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_resized = cv2.resize(frame_rgb, (640, 640)) + frame_transposed = frame_resized.transpose((2, 0, 1)) + frame_contiguous = np.ascontiguousarray(frame_transposed) + frame_tensor = torch.from_numpy(frame_contiguous).float() + frame_normalized = frame_tensor / 255.0 + frame_batched = frame_normalized.unsqueeze(0) + return frame_batched + +def process_task(): + print("开始处理任务,等待Kafka消息...") + for message in consumer: + print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}") + task = message.value + task_id = task['task_id'] + filename = task['filename'] + file_type = task['file_type'] + + print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}") + + file_path = os.path.join(UPLOAD_DIR, filename) + # 检查键类型并更新状态 + task_key = f"task:{task_id}" + try: + key_type = main_redis_client.type(task_key) + if key_type != b'hash': + main_redis_client.delete(task_key) + main_redis_client.hset(task_key, "status", "processing") + print(f"任务 {task_id} 状态更新为 'processing'") + except redis.exceptions.ResponseError as e: + print(f"更新任务 {task_id} 状态时出错: {str(e)}") + continue # 跳过这个任务,继续处理下一个 + + try: + if file_type == "image": + print(f"开始处理图像: {filename}") + json_results, annotated_img = process_image(file_path) + if json_results and annotated_img is not None: + result_filename = f"pose_{filename}" + result_path = os.path.join(RESULT_DIR, 
result_filename) + cv2.imwrite(result_path, annotated_img) + + redis_client.hmset(f"pose_result:{task_id}", { + "result": json.dumps(json_results), + "result_file": result_filename + }) + main_redis_client.hmset(f"task:{task_id}", { + "status": "completed", + "result_type": "pose", + "result_key": f"pose_result:{task_id}" + }) + print(f"图像 {filename} 处理完成,结果已保存") + else: + print(f"图像 {filename} 处理失败") + main_redis_client.hset(f"task:{task_id}", "status", "failed") + else: # video + print(f"开始处理视频: {filename}") + json_results = process_video(file_path) + if json_results: + result_filename = f"pose_{filename}" + redis_client.hmset(f"pose_result:{task_id}", { + "result": json.dumps(json_results), + "result_file": result_filename + }) + main_redis_client.hmset(f"task:{task_id}", { + "status": "completed", + "result_type": "pose", + "result_key": f"pose_result:{task_id}" + }) + print(f"视频 {filename} 处理完成,结果已保存") + else: + print(f"视频 {filename} 处理失败") + main_redis_client.hset(f"task:{task_id}", "status", "failed") + except Exception as e: + print(f"处理任务 {task_id} 时出错: {str(e)}") + main_redis_client.hmset(f"task:{task_id}", { + "status": "failed", + "error": str(e) + }) + + print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...") +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'hmset': + value = redis_client.hgetall(f"pose_result:{key}") + if value: + result = {k.decode(): v.decode() for k, v in value.items()} + print(f"Result update for task {key}: {result}") + + +if __name__ == "__main__": + print("pose处理程序启动...") + # 启动处理任务的线程 + task_thread = threading.Thread(target=process_task, daemon=True) + task_thread.start() + print("任务处理线程已启动") + + # 启动Redis监听线程 + redis_thread = 
threading.Thread(target=listen_redis_changes, daemon=True) + redis_thread.start() + print("Redis监听线程已启动") + + print("主程序进入等待状态...") + # 保持主线程运行 + task_thread.join() + redis_thread.join() \ No newline at end of file diff --git a/api/producer.py b/api/producer.py new file mode 100644 index 0000000..32a7b46 --- /dev/null +++ b/api/producer.py @@ -0,0 +1,419 @@ +# main.py +from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer +from redis import Redis +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +import string +from decord import VideoReader +from PIL import Image +from fastapi.responses import FileResponse +import logging +from config import * + +app = FastAPI() +v1_app = FastAPI() +app.mount("/v1", v1_app) + + +# CORS设置 +# ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +KAFKA_BROKER = KAFKA_BROKER +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = MAIN_REDIS_DB +REDIS_API_DB = REDIS_API_DB +REDIS_API_USAGE_DB = REDIS_API_USAGE_DB +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 定义支持的任务类型 +KAFKA_TOPICS = { + 'pose': 'pose', + 'mediapipe': 'mediapipe', + 'qwenvl': 'qwenvl', + 'yolo': 'yolo', + 'fall': 'fall', + 'face': 'face', + 'cpm': 'cpm' +} + +TASK_TYPES = list(KAFKA_TOPICS.keys()) + + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + 
password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) +redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db']) +redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db']) +redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db']) +redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db']) +redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db']) +redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db']) +redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db']) + +@v1_app.get('/favicon.ico', include_in_schema=False) +async def favicon(): + file_name = "favicon.ico" + file_path = os.path.join(app.root_path, "static", file_name) + if os.path.isfile(file_path): + return FileResponse(file_path) + else: + return {"message": "Favicon not found"}, 404 + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + 
logging.info(f"验证API密钥: {api_key}") + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + logging.warning(f"API密钥不存在: {api_key}") + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + logging.warning(f"API密钥已停用: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + logging.warning(f"API密钥已过期: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + logging.info(f"API密钥验证成功: {api_key}") + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = Image.open(file_path) + width, height = img.size + pixel_count = width * height + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count 
/ 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + vr = VideoReader(file_path) + fps = vr.get_avg_fps() + frame_count = len(vr) + width, height = vr[0].shape[1], vr[0].shape[0] + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + + +async def upload_file(file: UploadFile, task_type: str, api_key_info: dict): + if task_type not in KAFKA_TOPICS: + raise HTTPException(status_code=400, detail="不支持的任务类型") + + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(file_path, file_type) + + # 检查并更新 token 使用量 + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + await update_token_usage(api_key, tokens_required, task_type) + + # 创建任务记录 + task_id = str(uuid.uuid4()) + task_data = { + "task_id": task_id, + "filename": new_filename, + "file_type": file_type, + "task_type": task_type, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + # 存储任务信息到Redis + redis_client.set(f"task:{task_id}", json.dumps(task_data)) + logging.info(f"任务信息已存储到Redis: {task_id}") + + # 发送任务到对应的Kafka主题 + kafka_topic = 
KAFKA_TOPICS[task_type] + producer.send(kafka_topic, task_data) + logging.info(f"任务已发送到Kafka主题: {kafka_topic}") + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0)) + + response_data = { + "message": "文件已上传并排队等待处理", + "task_id": task_id, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{task_type}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + } + logging.info(f"上传文件完成: {task_id}") + return JSONResponse(content=response_data) + +# 为每个任务类型创建单独的端点 +@v1_app.post("/pose") +async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + logging.info(f"收到 /pose端点的请求") + return await upload_file(file, task_type="pose", api_key_info=api_key_info) + +@v1_app.post("/cpm") +async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="cpm", api_key_info=api_key_info) + +@v1_app.post("/qwenvl") +async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info) + +@v1_app.post("/yolo") +async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="yolo", api_key_info=api_key_info) + +@v1_app.post("/fall") +async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="fall", api_key_info=api_key_info) + +@v1_app.post("/face") +async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + logging.info(f"收到 /face 端点的请求") + return await upload_file(file, task_type="face", api_key_info=api_key_info) + 
+@v1_app.post("/mediapipe") +async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info) + + +@v1_app.get("/result/{task_id}") +async def get_result(task_id: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + # 从 REDIS_DB (15) 获取任务状态 + task_info = redis_client.hgetall(f"task:{task_id}") + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()} + + if task_info['status'] != 'completed': + return {"status": task_info['status'], "message": "Task is not completed yet"} + + result_type = task_info['result_type'] + result_key = task_info['result_key'] + + # 根据任务类型选择相应的 Redis 客户端 + redis_client_map = { + 'pose': redis_pose_client, + 'cpm': redis_cpm_client, + 'yolo': redis_yolo_client, + 'face': redis_face_client, + 'fall': redis_fall_client, + 'mediapipe': redis_mediapipe_client, + 'qwenvl': redis_qwenvl_client + } + + result_redis = redis_client_map.get(result_type) + if not result_redis: + raise HTTPException(status_code=400, detail="Unsupported result type") + + result = result_redis.hgetall(result_key) + if not result: + raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found") + + result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} + + # 将 result 字段解析为 JSON(如果存在) + if 'result' in result: + result['result'] = json.loads(result['result']) + + return { + "status": "completed", + "result_type": result_type, + "result": result + } + +@v1_app.get("/annotated/{task_id}") +async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, 
detail="无效的API密钥") + + # 从 REDIS_DB (15) 获取任务信息 + task_info = redis_client.hgetall(f"task:{task_id}") + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()} + + if task_info['status'] != 'completed': + raise HTTPException(status_code=400, detail="Task is not completed yet") + + result_type = task_info.get('result_type') + result_key = task_info.get('result_key') + + if not result_key: + raise HTTPException(status_code=404, detail="Result key not found") + + if result_type in ['cpm', 'qwenvl']: + raise HTTPException(status_code=400, detail="Annotated image not available for this task type") + + # 根据任务类型选择相应的 Redis 客户端 + redis_client_map = { + 'pose': redis_pose_client, + 'yolo': redis_yolo_client, + 'face': redis_face_client, + 'fall': redis_fall_client, + 'mediapipe': redis_mediapipe_client + } + + result_redis = redis_client_map.get(result_type) + if not result_redis: + raise HTTPException(status_code=400, detail="Unsupported result type") + + result = result_redis.hgetall(result_key) + if not result: + raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found") + + result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} + + result_file = result.get('result_file') + if not result_file: + raise HTTPException(status_code=404, detail="Result file not found") + + file_path = os.path.join(RESULT_DIR, result_file) + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Result image file not found") + + return FileResponse(file_path, media_type="image/png") + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8005) \ No newline at end of file diff --git a/api/qwenvl_analyze.py b/api/qwenvl_analyze.py new file mode 100644 index 0000000..95823fe --- /dev/null +++ b/api/qwenvl_analyze.py @@ -0,0 +1,271 @@ +import os +import json +import uuid +from datetime 
import datetime, timedelta +from kafka import KafkaConsumer +from transformers import Qwen2VLForConditionalGeneration, AutoProcessor +from qwen_vl_utils import process_vision_info +from decord import VideoReader, cpu +from PIL import Image +import redis +from redis import Redis +import io +import re +import threading +from config import * + +# 配置 +MODEL_PATH = QWEN_MODEL_PATH +KAFKA_BROKER = KAFKA_BROKER +KAFKA_TOPIC = WORKER_CONFIGS["qwenvl_analyze"]["kafka_topic"] +KAFKA_GROUP_ID = f"qwenvl_analyze_{KAFKA_GROUP_ID_PREFIX}" + +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = WORKER_CONFIGS["qwenvl_analyze"]["redis_db"] # Worker使用的Redis DB +MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB + +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +main_redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=MAIN_REDIS_DB +) + + +# 初始化模型 +model = Qwen2VLForConditionalGeneration.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cuda:0" +) + +min_pixels = 128*28*28 +max_pixels = 512*28*28 +processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels) + +class MediaAnalysisSystem: + def __init__(self, model, processor): + self.model = model + self.processor = processor + self.MAX_NUM_FRAMES = 10 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file) + 
sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_media(self, media_data, object_name, media_type='image'): + if not media_data: + raise ValueError(f"Empty {media_type} data for {object_name}") + + print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes") + + if media_type == 'video': + frames = self.encode_video(media_data) + media_content = {"type": "video", "video": frames, "fps": 1.0} + else: # image + image = Image.open(io.BytesIO(media_data)) + media_content = {"type": "image", "image": image} + + messages = [ + { + "role": "user", + "content": [ + media_content, + {"type": "text", "text": """您是一个高级OCR和文本分析助手。您的主要任务是: + 1) 从这个""" + ("视频" if media_type == "video" else "图片") + """中高精度提取所有文本内容,包括标准文本和数字数据, + 2) 对提取的内容进行全面分析, + 3) 将信息进行逻辑性的组织和结构化, + 4) 提供从文本/数字中发现的详细见解。 + + 对于数字数据,请包含统计分析。以清晰的层次格式呈现您的发现,请提供完整的原始文本。如果遇到任何不清楚或模糊的元素,请突出显示并要求澄清。 + + 请支持: + - 多种语言(简体中文、繁体中文、英文) + - 不同的文本方向和布局 + - 表格、图表和结构化数据格式 + - 特殊字符、符号和数学符号 + - 原始格式和文本位置信息"""}, + ], + } + ] + + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_inputs, video_inputs = process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to('cuda:0') + generated_ids = self.model.generate(**inputs, max_new_tokens=2048) + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + answer = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + result = { + 
"original_answer": answer, + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + if media_type == 'video': + result["num_frames"] = len(frames) + + return result + + def process_video(self, video_data, object_name): + return self.process_media(video_data, object_name, media_type='video') + + def process_image(self, image_data, object_name): + return self.process_media(image_data, object_name, media_type='image') + + @staticmethod + def extract_time_from_filename(object_name): + filename = os.path.basename(object_name) + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + +# 初始化 MediaAnalysisSystem +media_analysis_system = MediaAnalysisSystem(model, processor) + +def process_task(): + print("开始处理任务,等待Kafka消息...") + for message in consumer: + print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}") + task = message.value + task_id = task['task_id'] + filename = task['filename'] + file_type = task['file_type'] + + print(f"解析任务信息: ID={task_id}, filename={filename}, type={file_type}") + + file_path = os.path.join(UPLOAD_DIR, filename) + # Check key type and update status + task_key = f"task:{task_id}" + try: + key_type = main_redis_client.type(task_key) + if key_type != b'hash': + main_redis_client.delete(task_key) + main_redis_client.hset(f"task:{task_id}", "status", "processing") + print(f"任务 {task_id} 状态更新为 'processing'") + except redis.exceptions.ResponseError as e: + print(f"更新任务 {task_id} 状态时出错: {str(e)}") + continue # Skip this task and continue with the next one + + try: + if file_type == "image": + print(f"Processing image: {filename}") + with open(file_path, 'rb') as f: + image_data = f.read() + result 
= media_analysis_system.process_image(image_data, filename) + else: # video + print(f"Processing video: {filename}") + with open(file_path, 'rb') as f: + video_data = f.read() + result = media_analysis_system.process_video(video_data, filename) + + if result: + redis_client.hset(f"qwenvl_analyze_result:{task_id}", mapping={ + "result": json.dumps(result), + "result_file": filename + }) + main_redis_client.hset(f"task:{task_id}", "status", "completed") + main_redis_client.hset(f"task:{task_id}", "result_type", "qwenvl_analyze") + main_redis_client.hset(f"task:{task_id}", "result_key", f"qwenvl_analyze_result:{task_id}") + print(f"{file_type.capitalize()} {filename} processed, result saved") + else: + print(f"{file_type.capitalize()} {filename} processing failed") + main_redis_client.hset(f"task:{task_id}", "status", "failed") + except Exception as e: + error_msg = str(e) + print(f"处理任务 {task_id} 时出错: {error_msg}") + try: + # 分开设置每个字段,避免使用字典 + main_redis_client.hset(f"task:{task_id}", "status", "failed") + main_redis_client.hset(f"task:{task_id}", "error", error_msg) + except Exception as redis_error: + print(f"更新任务状态时出错: {str(redis_error)}") + + print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...") + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:qwenvl_analyze_result:*') # Listen for changes on all qwenvl_result keys + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'hset': + value = redis_client.hgetall(f"qwenvl_analyze_result:{key}") + if value: + result = {k.decode(): v.decode() for k, v in value.items()} + print(f"Result update for task {key}: {result}") + +if __name__ == "__main__": + print("qwenvl_analyze处理程序启动...") + # Start the task processing thread + task_thread = threading.Thread(target=process_task, daemon=True) + task_thread.start() + print("任务处理线程已启动") + + # Start the 
Redis listening thread + redis_thread = threading.Thread(target=listen_redis_changes, daemon=True) + redis_thread.start() + print("Redis监听线程已启动") + + print("主程序进入等待状态...") + # Keep the main thread running + task_thread.join() + redis_thread.join() \ No newline at end of file diff --git a/api/qwenvl_scene.py b/api/qwenvl_scene.py new file mode 100644 index 0000000..c5f594d --- /dev/null +++ b/api/qwenvl_scene.py @@ -0,0 +1,329 @@ +import os +import json +import uuid +from datetime import datetime, timedelta +from kafka import KafkaConsumer +from transformers import Qwen2VLForConditionalGeneration, AutoProcessor +from qwen_vl_utils import process_vision_info +from decord import VideoReader, cpu +from PIL import Image +import redis +from redis import Redis +import io +import re +import threading +from config import * + +# 配置 +MODEL_PATH = QWEN_MODEL_PATH +KAFKA_BROKER = KAFKA_BROKER +KAFKA_TOPIC = WORKER_CONFIGS["qwenvl"]["kafka_topic"] +KAFKA_GROUP_ID = f"qwenvl_{KAFKA_GROUP_ID_PREFIX}" + +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = WORKER_CONFIGS["qwenvl"]["redis_db"] # Worker使用的Redis DB +MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB + +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +main_redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=MAIN_REDIS_DB +) + + +# 初始化模型 +model = Qwen2VLForConditionalGeneration.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cuda:1" +) + +min_pixels = 128*28*28 +max_pixels = 512*28*28 
class MediaAnalysisSystem:
    """Describe surveillance images and videos with a Qwen2-VL model.

    ``process_image`` / ``process_video`` take the raw bytes of a file, send
    them to the model together with a fixed Chinese analysis prompt, and
    return a dict holding the model's answer, keyword-based
    ``extracted_info`` and the time the analysis finished.
    """

    # Count words that can be converted into an exact head count.
    _COUNT_WORDS = {
        '一': 1, '一个': 1, '二': 2, '两': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
    }

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        # Upper bound on the number of frames handed to the model per video.
        self.MAX_NUM_FRAMES = 10

    def encode_video(self, video_data):
        """Decode *video_data* (bytes) into a list of PIL frames.

        Roughly one frame per second is taken; if that yields more than
        ``MAX_NUM_FRAMES`` frames, they are thinned out uniformly.
        """
        def uniform_sample(l, n):
            gap = len(l) / n
            return [l[int(i * gap + gap / 2)] for i in range(n)]

        video_file = io.BytesIO(video_data)
        vr = VideoReader(video_file)
        # An average rate below 0.5 fps rounds to 0, and range() rejects a
        # step of 0, so never sample with a step smaller than one frame.
        sample_fps = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_media(self, media_data, object_name, media_type='image'):
        """Analyse one image or video and return the result dict.

        Args:
            media_data: raw bytes of the file.
            object_name: name used in log and error messages.
            media_type: ``'video'`` for videos; anything else is treated as
                an image.

        Returns:
            dict with ``original_answer``, ``extracted_info``, ``timestamp``
            and, for videos, ``num_frames``.

        Raises:
            ValueError: if *media_data* is empty.
        """
        if not media_data:
            raise ValueError(f"Empty {media_type} data for {object_name}")

        print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")

        if media_type == 'video':
            frames = self.encode_video(media_data)
            media_content = {"type": "video", "video": frames, "fps": 1.0}
        else:  # image
            image = Image.open(io.BytesIO(media_data))
            media_content = {"type": "image", "image": image}

        messages = [
            {
                "role": "user",
                "content": [
                    media_content,
                    {"type": "text", "text": f"""请对这{'段监控视频' if media_type == 'video' else '张监控图像'}进行详细分析,包括以下方面:
                    1. 场景中人数的精确统计
                    2. 每个人的个人行为分析
                    3. 面部表情识别和情绪状态评估
                    4. 整体场景和环境的详细描述
                    5. 人与人之间的互动情况
                    6. 时间和环境条件(如果可见)
                    7. 任何可疑或异常活动
                    8. 人员的具体特征(估计年龄范围、性别、着装)
                    9. 人员的{'移动模式和方向' if media_type == 'video' else '位置和姿态'}
                    10. 携带的物品或物体
                    11. 群体动态和聚集情况
                    12. {'视频' if media_type == 'video' else '图像'}中的时间戳信息(如果有)

                    请用清晰、有条理的格式描述,并突出重要发现。"""},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to('cuda:1')
        generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
        # generate() returns prompt + completion; keep only the completion.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        answer = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        result = {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        if media_type == 'video':
            result["num_frames"] = len(frames)

        return result

    def process_video(self, video_data, object_name):
        """Analyse a video given as raw bytes; see ``process_media``."""
        return self.process_media(video_data, object_name, media_type='video')

    def process_image(self, image_data, object_name):
        """Analyse an image given as raw bytes; see ``process_media``."""
        return self.process_media(image_data, object_name, media_type='image')

    @staticmethod
    def extract_time_from_filename(object_name):
        """Return ``(start, end)`` parsed from a ``YYYYMMDD_HHMMSS*`` file name.

        ``end`` is always ``start`` plus ten seconds. If the name does not
        follow the pattern (including names without an underscore), the
        current time is used as ``start``.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        try:
            # parts[1] is missing when the name has no underscore at all.
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull coarse structured facts out of the model's free-text answer.

        Returns a dict with:
            environment: first known environment keyword found, else None.
            num_people: an exact head count when the text states one
                (digits or a single count word such as 三 / 两), else None.
                Vague wording such as "几个人" leaves it None rather than
                reporting a misleading 0.
            actions / interactions / objects / furniture: the known
                keywords found in the text, in list order, without
                duplicates.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        lowered = answer.lower()
        for env in ["办公室", "室内", "室外", "会议室"]:
            if env in lowered:
                info["environment"] = env
                break

        # Ordered from most to least precise; the first pattern that matches
        # decides the outcome. Compound numerals (e.g. 二十) are not handled.
        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|两|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
        ]
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if not match:
                continue
            token = match.group(1)
            if token.isdigit():
                info["num_people"] = int(token)
            else:
                # "几个" has no exact value: .get() leaves num_people as None.
                info["num_people"] = MediaAnalysisSystem._COUNT_WORDS.get(token)
            break

        keyword_groups = {
            "actions": ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"],
            "interactions": ["互动", "交流", "身体语言", "交谈", "讨论", "开会"],
            "objects": ["水瓶", "办公用品", "文件", "电脑"],
            "furniture": ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"],
        }
        for key, keywords in keyword_groups.items():
            # dict.fromkeys drops repeated keywords while keeping their order.
            info[key] = [item for item in dict.fromkeys(keywords) if item in answer]

        return info
def listen_redis_changes():
    """Log every update written to a ``qwenvl_result:<task_id>`` hash.

    Subscribes to Redis keyspace notifications for the worker's result
    keys and, on each ``hset`` event, reads the hash back and prints it.
    Blocks forever; intended to run in its own thread.

    NOTE(review): keyspace notifications are only delivered when the server
    has ``notify-keyspace-events`` enabled - confirm the deployment sets it.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are scoped to one database, so the channel must name
    # the DB this worker writes its results to rather than a fixed index.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:qwenvl_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # Channel layout: __keyspace@<db>__:qwenvl_result:<task_id>.
        # maxsplit keeps task ids that themselves contain ':' intact.
        key = message['channel'].decode('utf-8').split(':', 2)[-1]
        operation = message['data'].decode('utf-8')
        if operation != 'hset':
            continue

        value = redis_client.hgetall(f"qwenvl_result:{key}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {key}: {result}")
# YOLO detector class
class yoloDetector:
    """Thin wrapper around an ultralytics ``YOLO`` model.

    Runs detection, converts the raw results into JSON-friendly dicts and
    draws them onto a frame.
    """

    def __init__(self, model_path):
        self.model = YOLO(model_path)

    def detect(self, frame):
        """Run the model on *frame* and return its raw results."""
        results = self.model(frame)
        return results

    def format_results(self, results, original_shape, input_size=640):
        """Convert raw results into a list of detection dicts.

        Args:
            results: iterable of result objects exposing ``boxes``.
            original_shape: ``(height, width, ...)`` of the original frame.
            input_size: side length of the square image the model was fed;
                box coordinates are rescaled from that space back to the
                original frame. Defaults to 640.

        Returns:
            list of ``{"class", "confidence", "bbox": [x1, y1, x2, y2]}``
            with ``bbox`` in original-frame pixels.
        """
        formatted_results = []
        frame_height, frame_width = original_shape[0], original_shape[1]
        for r in results:
            boxes = r.boxes
            if boxes is None:
                # A result without box output contributes no detections.
                continue
            for box in boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()

                x1, x2 = [x * frame_width / input_size for x in [x1, x2]]
                y1, y2 = [y * frame_height / input_size for y in [y1, y2]]

                conf = box.conf.item()
                cls = int(box.cls.item())
                name = self.model.names[cls]

                formatted_results.append({
                    "class": name,
                    "confidence": conf,
                    "bbox": [x1, y1, x2, y2]
                })
        return formatted_results

    def draw_results(self, frame, formatted_results):
        """Draw boxes and ``"<class> <confidence>"`` labels onto *frame*.

        The frame is modified in place and also returned.
        """
        for result in formatted_results:
            x1, y1, x2, y2 = map(int, result['bbox'])
            name = result['class']
            conf = result['confidence']

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{name} {conf:.2f}"
            (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)

            # The label normally sits on top of the box; when the box touches
            # the top edge that position is off-frame, so move it inside.
            label_top = y1 - text_height - 5
            if label_top < 0:
                label_top = y1
            label_bottom = label_top + text_height + 5

            cv2.rectangle(frame, (x1, label_top), (x1 + text_width, label_bottom), (0, 255, 0), -1)
            cv2.putText(frame, label, (x1, label_bottom - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

        return frame
def process_video(video_path):
    """Run detection on about one frame per second of a video.

    Each sampled frame is annotated and appended to
    ``RESULT_DIR/yolo_<file name>``, the name the task handler reports as
    ``result_file``.

    Args:
        video_path: path of the uploaded video.

    Returns:
        list of ``{"frame": index, "detections": [...]}`` dicts, or ``None``
        if the video cannot be opened or processing fails.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"处理视频时出错: {video_path}")
            return None

        frame_count = 0
        json_results = []

        # Some containers report 0 fps; "frame_count % 0" would abort the
        # whole video, so fall back to handling every frame.
        fps = max(1, int(cap.get(cv2.CAP_PROP_FPS)))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_shape = (height, width)

        # Never write next to (or over) the upload, and use the yolo_ prefix
        # that is reported back to clients for this result.
        result_path = os.path.join(RESULT_DIR, f"yolo_{os.path.basename(video_path)}")
        out = cv2.VideoWriter(result_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % fps == 0:
                preprocessed_frame = preprocess_frame(frame)

                results = detector.detect(preprocessed_frame)
                frame_json_results = detector.format_results(results, original_shape)
                json_results.append({"frame": frame_count, "detections": frame_json_results})

                # NOTE(review): only sampled frames are written, yet the
                # writer runs at the source fps, so the output plays faster
                # than the input - confirm this is intended.
                annotated_frame = detector.draw_results(frame, frame_json_results)
                out.write(annotated_frame)

            frame_count += 1

        return json_results
    except Exception as e:
        print(f"处理视频时出错: {str(e)}")
        return None
    finally:
        # Release the handles on every path so a failed video does not keep
        # the input or the partially written output open.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()


def preprocess_frame(frame):
    """Turn a BGR frame into a normalised ``1x3x640x640`` float tensor."""
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_resized = cv2.resize(frame_rgb, (640, 640))
    # HWC -> CHW, then make the memory contiguous for torch.from_numpy.
    frame_transposed = frame_resized.transpose((2, 0, 1))
    frame_contiguous = np.ascontiguousarray(frame_transposed)
    frame_tensor = torch.from_numpy(frame_contiguous).float()
    frame_normalized = frame_tensor / 255.0
    frame_batched = frame_normalized.unsqueeze(0)
    return frame_batched
def listen_redis_changes():
    """Log every update written to a ``yolo_result:<task_id>`` hash.

    Subscribes to Redis keyspace notifications for the worker's result
    keys and, on each ``hset`` event, reads the hash back and prints it.
    Blocks forever; intended to run in its own thread.

    NOTE(review): keyspace notifications are only delivered when the server
    has ``notify-keyspace-events`` enabled - confirm the deployment sets it.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are scoped to one database, so the channel must name
    # the DB this worker writes its results to rather than a fixed index.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:yolo_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # Channel layout: __keyspace@<db>__:yolo_result:<task_id>.
        # maxsplit keeps task ids that themselves contain ':' intact.
        key = message['channel'].decode('utf-8').split(':', 2)[-1]
        operation = message['data'].decode('utf-8')

        # Redis reports HMSET writes as "hset" events; there is no "hmset"
        # event, so filtering on that name would never match.
        if operation != 'hset':
            continue

        value = redis_client.hgetall(f"yolo_result:{key}")
        if value:
            result = {k.decode(): v.decode() for k, v in value.items()}
            print(f"Result update for task {key}: {result}")
+DUFU_REF_AUDIO=sample/dufu.wav +DUFU_REF_TEXT=sample/dufu.txt + +HEJIONG_REF_AUDIO=sample/hejiong.wav +HEJIONG_REF_TEXT=sample/hejiong.txt + +MAHUATENG_REF_AUDIO=sample/mahuateng.wav +MAHUATENG_REF_TEXT=sample/mahuateng.txt + +LIDAN_REF_AUDIO=sample/lidan.wav +LIDAN_REF_TEXT=sample/lidan.txt + +YUHUA_REF_AUDIO=sample/yuhua.wav +YUHUA_REF_TEXT=sample/yuhua.txt + +LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav +LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt + +DABING_REF_AUDIO=sample/dabing.wav +DABING_REF_TEXT=sample/dabing.txt + +LUOXIANG_REF_AUDIO=sample/luoxiang.wav +LUOXIANG_REF_TEXT=sample/luoxiang.txt + +XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav +XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt + + +REDIS_GIRL_DB = 15 +REDIS_WOMAN_DB = 16 +REDIS_MAN_DB = 17 +REDIS_LEIJUN_DB = 18 +REDIS_DUFU_DB = 19 +REDIS_HEJIONG_DB = 20 +REDIS_MAHUATENG_DB = 21 +REDIS_LIDAN_DB = 22 +REDIS_DABING_DB = 23 +REDIS_LUOXIANG_DB = 24 +REDIS_XUZHIYUAN_DB = 25 +REDIS_YUHUA_DB = 26 +REDIS_LIUZHENYUN_DB = 27 + + + diff --git a/api_chat/GPT_SoVITS/.gitkeep b/api_chat/GPT_SoVITS/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/api_chat/OpenBMB/.gitkeep b/api_chat/OpenBMB/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/api_chat/TEMP/jieba.cache b/api_chat/TEMP/jieba.cache new file mode 100644 index 0000000..429e68b Binary files /dev/null and b/api_chat/TEMP/jieba.cache differ diff --git a/api_chat/__pycache__/ollama.cpython-311.pyc b/api_chat/__pycache__/ollama.cpython-311.pyc new file mode 100644 index 0000000..e90f8f8 Binary files /dev/null and b/api_chat/__pycache__/ollama.cpython-311.pyc differ diff --git a/api_chat/__pycache__/ollama_api.cpython-311.pyc b/api_chat/__pycache__/ollama_api.cpython-311.pyc new file mode 100644 index 0000000..ee93b61 Binary files /dev/null and b/api_chat/__pycache__/ollama_api.cpython-311.pyc differ diff --git a/api_chat/__pycache__/ollamas.cpython-311.pyc b/api_chat/__pycache__/ollamas.cpython-311.pyc new file mode 100644 index 
0000000..cf15090 Binary files /dev/null and b/api_chat/__pycache__/ollamas.cpython-311.pyc differ diff --git a/api_chat/__pycache__/sovits_api.cpython-311.pyc b/api_chat/__pycache__/sovits_api.cpython-311.pyc new file mode 100644 index 0000000..8a1c93b Binary files /dev/null and b/api_chat/__pycache__/sovits_api.cpython-311.pyc differ diff --git a/api_chat/__pycache__/whisper.cpython-311.pyc b/api_chat/__pycache__/whisper.cpython-311.pyc new file mode 100644 index 0000000..7e33804 Binary files /dev/null and b/api_chat/__pycache__/whisper.cpython-311.pyc differ diff --git a/api_chat/__pycache__/whisper_api.cpython-311.pyc b/api_chat/__pycache__/whisper_api.cpython-311.pyc new file mode 100644 index 0000000..f0e4bfe Binary files /dev/null and b/api_chat/__pycache__/whisper_api.cpython-311.pyc differ diff --git a/api_chat/asr.py b/api_chat/asr.py new file mode 100644 index 0000000..c07d8ec --- /dev/null +++ b/api_chat/asr.py @@ -0,0 +1,110 @@ +import whisper +import os +import json +import redis +from dotenv import load_dotenv +from kafka import KafkaConsumer +import asyncio + +# 设置要使用的GPU ID +GPU_ID = 1 # 修改这个值来选择要使用的GPU + +# 设置CUDA_VISIBLE_DEVICES环境变量 +os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID) + +# 加载环境变量 +load_dotenv() + +print("正在加载Whisper模型...") +model = whisper.load_model("large-v3") +print("Whisper模型加载完成。") + +# Kafka配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC') + +# Redis配置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) + +# 创建Redis客户端 +redis_asr_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_ASR_DB, + password=REDIS_PASSWORD +) + +redis_task_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_TASK_DB, + password=REDIS_PASSWORD +) + +async def process_audio(file_path: str, cache_key: str): + try: 
+ # 更新任务状态 + redis_task_client.set(f"task_status:{cache_key}", "processing") + + result = model.transcribe(file_path) + transcription = result['text'] + + print(f"处理了文件: {file_path}") + print(f"转录结果: {transcription}") + + redis_asr_client.setex(cache_key, 3600, transcription) + + result_data = { + 'transcription': transcription + } + redis_asr_client.publish('asr_results', json.dumps(result_data)) + + # 更新任务状态 + redis_task_client.set(f"task_status:{cache_key}", "completed") + + os.remove(file_path) + + except Exception as e: + print(f"处理音频文件时发生错误: {str(e)}") + # 更新任务状态 + redis_task_client.set(f"task_status:{cache_key}", "error") + +async def kafka_consumer(): + consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + value_deserializer=lambda x: json.loads(x.decode('utf-8')), + group_id='asr_group', + auto_offset_reset='earliest', + enable_auto_commit=True + ) + + print(f"ASR消费者已启动") + + for message in consumer: + try: + task = message.value + file_path = task.get('file_path') + task_id = task.get('task_id') + status = task.get('status') + + if not file_path or not task_id or status != 'queued': + print(f"收到无效任务: {task}") + continue + + cache_key = f"asr:{task_id}" + + print(f"开始处理任务: {cache_key}") + await process_audio(file_path, cache_key) + print(f"完成处理任务: {cache_key}") + + except Exception as e: + print(f"处理消息时发生错误: {str(e)}") + +if __name__ == "__main__": + print("启动Kafka消费者处理ASR请求...") + asyncio.run(kafka_consumer()) \ No newline at end of file diff --git a/api_chat/before/asr.py b/api_chat/before/asr.py new file mode 100644 index 0000000..db0f524 --- /dev/null +++ b/api_chat/before/asr.py @@ -0,0 +1,115 @@ +import whisper +import os +import json +import redis +from dotenv import load_dotenv +from kafka import KafkaConsumer +import threading + +# 设置要使用的GPU ID +GPU_ID = 1 # 修改这个值来选择要使用的GPU + +# 设置CUDA_VISIBLE_DEVICES环境变量 +os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID) + +# 加载环境变量 +load_dotenv() + +print("正在加载Whisper模型...") +model = 
whisper.load_model("large-v3") +print("Whisper模型加载完成。") + +# Kafka配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC') + +# Redis配置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) + +# 创建Redis客户端 +redis_asr_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_ASR_DB, + password=REDIS_PASSWORD +) + +redis_task_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_TASK_DB, + password=REDIS_PASSWORD +) + +def process_audio(file_path: str, client_id: str, cache_key: str): + try: + # 设置任务状态为 "processing" + redis_task_client.set(f"task_status:{cache_key}", "processing") + + result = model.transcribe(file_path) + transcription = result['text'] + + print(f"处理了文件: {file_path}") + print(f"转录结果: {transcription}") + + # 将结果存入Redis缓存 + redis_asr_client.setex(cache_key, 3600, transcription) # 缓存1小时 + + # 发布结果到Redis频道 + result_data = { + 'client_id': client_id, + 'transcription': transcription + } + redis_asr_client.publish('asr_results', json.dumps(result_data)) + + # 设置任务状态为 "completed" + redis_task_client.set(f"task_status:{cache_key}", "completed") + + # 清理临时文件 + os.remove(file_path) + + except Exception as e: + print(f"处理音频文件时发生错误: {str(e)}") + # 设置任务状态为 "error" + redis_task_client.set(f"task_status:{cache_key}", "error") + +def kafka_consumer(): + consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + value_deserializer=lambda x: json.loads(x.decode('utf-8')), + group_id='asr_group', + auto_offset_reset='earliest', + enable_auto_commit=True + ) + + print(f"ASR消费者已启动") + + for message in consumer: + try: + task = message.value + file_path = task.get('file_path') + task_id = task.get('task_id') + status = task.get('status') + + if not file_path or not task_id or status != 'queued': + print(f"收到无效任务: {task}") + 
async def process_chat_request(chat_request):
    """Answer one chat request through the local Ollama server.

    Rebuilds the prompt from the session history stored in Redis, streams
    the completion from ``/api/generate``, appends the exchange to the
    history and records the task status (``processing`` -> ``completed``
    or ``error``) under ``task_status:<session_id>``.

    Args:
        chat_request: dict with ``session_id`` and ``query`` and an
            optional ``model`` (default ``qwen2.5:3b``).
    """
    # Read the id once, outside the try block, so the error handler never
    # has to index the request again (which could raise a second time).
    session_id = chat_request.get('session_id')
    if not session_id:
        print(f"处理 session {session_id} 时出错: missing session_id")
        return

    try:
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        # Mark the task as "processing".
        redis_task_client.set(f"task_status:{session_id}", "processing")

        # Load the conversation history from Redis.
        history = json.loads(redis_client.get(session_id) or '[]')

        # Build the full prompt including earlier turns.
        prompt_parts = [DEFAULT_SYSTEM_PROMPT + "\n"]
        for past_query, past_response in history:
            prompt_parts.append(f"用户: {past_query}\n助手: {past_response}\n")
        prompt_parts.append(f"用户: {query}\n助手:")
        full_prompt = "".join(prompt_parts)

        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            # Sampling parameters are only honoured inside "options"; a
            # top-level "temperature" is ignored by the generate endpoint.
            "options": {"temperature": 0},
        }

        chunks = []
        # The timeout keeps a stalled server from blocking this consumer
        # forever; the context manager returns the connection to the pool.
        with requests.post(
            "http://127.0.0.1:11434/api/generate",
            json=data,
            stream=True,
            timeout=(5, 300),
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue
                json_data = json.loads(line)
                if 'response' in json_data:
                    chunks.append(json_data['response'])
        text_output = "".join(chunks)

        # Append this exchange to the history.
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))

        # Mark the task as "completed".
        redis_task_client.set(f"task_status:{session_id}", "completed")

        print(f"处理完成 session {session_id}: {text_output}")

    except Exception as e:
        print(f"处理 session {session_id} 时出错: {str(e)}")
        try:
            # Mark the task as "error".
            redis_task_client.set(f"task_status:{session_id}", "error")
        except redis.RedisError as redis_error:
            # Redis itself may be the failure; do not let that kill the
            # consumer thread.
            print(f"处理 session {session_id} 时出错: {str(redis_error)}")
+import requests +import json +from typing import List, Tuple +from kafka import KafkaConsumer, TopicPartition +from concurrent.futures import ThreadPoolExecutor +import threading +import asyncio +import redis +import uuid +import logging +import uvicorn +from dotenv import load_dotenv +import os +import torch +from modelscope import AutoModelForCausalLM, AutoTokenizer + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +app = FastAPI() + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +torch.cuda.set_device(device) +print(f"Using device: {device}") + +# Load MiniCPM3-4B model +path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B" +tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True) + +# 加载 .env 文件 +load_dotenv() +# CORS 配置 +ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',') + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Kafka 设置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC') +KAFKA_CONSUMER_GROUP = 'mini3_group' +KAFKA_CONSUMER_NUM = 1 + +# Redis 设置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_DB = int(os.getenv('REDIS_MINI3_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# 创建Redis客户端 +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD # 使用密码进行认证 +) +# 创建Kafka消费者 +def create_kafka_consumer(): + return KafkaConsumer( + bootstrap_servers=KAFKA_BROKER, + auto_offset_reset='earliest', + enable_auto_commit=True, + group_id=KAFKA_CONSUMER_GROUP, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) + ) + +# Kafka消费者函数 +def kafka_consumer(consumer, 
consumer_id): + # 获取消费者分配的分区 + consumer.subscribe([KAFKA_TOPIC]) + partitions = consumer.assignment() + + logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}") + + for message in consumer: + partition = message.partition + offset = message.offset + chat_request = message.value # 直接使用 message.value,它已经是一个字典 + session_id = chat_request['session_id'] + query = chat_request['query'] + + logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:") + + asyncio.run(process_chat_request(chat_request)) + +# 启动Kafka消费者线程 +def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM): + consumers = [] + for i in range(num_consumers): + consumer = create_kafka_consumer() + consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True) + consumer_thread.start() + consumers.append((consumer, consumer_thread)) + return consumers + +DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思" + +class ChatRequest(BaseModel): + session_id: str + query: str + model: str = "minicpm3-4b" + +class ChatResponse(BaseModel): + response: str + history: List[Tuple[str, str]] + +# 处理聊天请求的异步函数 +async def process_chat_request(chat_request): + try: + response = await chat(ChatRequest(**chat_request)) + print(f"Processed message for session {chat_request['session_id']}: {response}") + except Exception as e: + print(f"Error processing message for session {chat_request['session_id']}: {str(e)}") + +@app.post("/mini3", response_model=ChatResponse) +async def chat(request: ChatRequest): + session_id = request.session_id + query = request.query + + # 从Redis获取历史记录 + history = json.loads(redis_client.get(session_id) or '[]') + + # 构建包含历史对话的完整提示 + messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + for past_query, past_response in history: + messages.append({"role": "user", "content": past_query}) + messages.append({"role": "assistant", "content": past_response}) + messages.append({"role": "user", 
"content": query}) + + try: + model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True) + + # 创建注意力掩码 + attention_mask = model_inputs.ne(tokenizer.pad_token_id).long() + + # 将输入移动到正确的设备(CPU或GPU) + model_inputs = model_inputs.to(device) + attention_mask = attention_mask.to(device) + + model_outputs = model.generate( + model_inputs, + attention_mask=attention_mask, + max_new_tokens=1024, + top_p=0.7, + temperature=0.7, + pad_token_id=tokenizer.eos_token_id, # 将pad_token_id设置为eos_token_id + do_sample=True + ) + + output_token_ids = model_outputs[0][len(model_inputs[0]):] + text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True) + + # 更新历史记录 + history.append((query, text_output)) + redis_client.set(session_id, json.dumps(history)) + + return ChatResponse(response=text_output, history=history) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/start_chat") +async def start_chat(): + session_id = str(uuid.uuid4()) + return {"session_id": session_id} + +if __name__ == '__main__': + # 启动Kafka消费者线程 + start_kafka_consumers() + + # 启动FastAPI服务器 + uvicorn.run(app, host="0.0.0.0", port=6003) \ No newline at end of file diff --git a/api_chat/before/ollama_api.py b/api_chat/before/ollama_api.py new file mode 100644 index 0000000..991d071 --- /dev/null +++ b/api_chat/before/ollama_api.py @@ -0,0 +1,170 @@ +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +import requests +import json +from typing import List, Tuple +from kafka import KafkaConsumer, TopicPartition +from concurrent.futures import ThreadPoolExecutor +import threading +import asyncio +import redis +import uuid +import logging +import uvicorn +from dotenv import load_dotenv +import os +import torch + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +app = FastAPI() + +device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") +torch.cuda.set_device(device) +print(f"Using device: {device}") + +# 加载 .env 文件 +load_dotenv() +# CORS 配置 +ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',') + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Kafka 设置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC') +KAFKA_CONSUMER_GROUP = 'chat_group' +KAFKA_CONSUMER_NUM = 1 + +# Redis 设置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_DB = int(os.getenv('REDIS_CHAT_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# 创建Redis客户端 +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD # 使用密码进行认证 +) +# 创建Kafka消费者 +def create_kafka_consumer(): + return KafkaConsumer( + bootstrap_servers=KAFKA_BROKER, + auto_offset_reset='earliest', + enable_auto_commit=True, + group_id=KAFKA_CONSUMER_GROUP, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) + ) + +# Kafka消费者函数 +def kafka_consumer(consumer, consumer_id): + # 获取消费者分配的分区 + consumer.subscribe([KAFKA_TOPIC]) + partitions = consumer.assignment() + + logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}") + + for message in consumer: + partition = message.partition + offset = message.offset + chat_request = message.value # 直接使用 message.value,它已经是一个字典 + session_id = chat_request['session_id'] + query = chat_request['query'] + + logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:") + + asyncio.run(process_chat_request(chat_request)) + +# 启动Kafka消费者线程 +def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM): + consumers = [] + for i in range(num_consumers): + consumer = create_kafka_consumer() + consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True) + consumer_thread.start() + consumers.append((consumer, 
consumer_thread)) + return consumers + +DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思" + +class ChatRequest(BaseModel): + session_id: str + query: str + model: str = "qwen2.5:3b" + +class ChatResponse(BaseModel): + response: str + history: List[Tuple[str, str]] + +# 处理聊天请求的异步函数 +async def process_chat_request(chat_request): + try: + response = await chat(ChatRequest(**chat_request)) + print(f"Processed message for session {chat_request['session_id']}: {response}") + except Exception as e: + print(f"Error processing message for session {chat_request['session_id']}: {str(e)}") + +@app.post("/chat", response_model=ChatResponse) +async def chat(request: ChatRequest): + session_id = request.session_id + query = request.query + model = request.model + + # 从Redis获取历史记录 + history = json.loads(redis_client.get(session_id) or '[]') + + # 构建包含历史对话的完整提示 + full_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + for past_query, past_response in history: + full_prompt += f"用户: {past_query}\n助手: {past_response}\n" + full_prompt += f"用户: {query}" + + data = { + "model": model, + "prompt": full_prompt, + "stream": True, + "temperature": 0 + } + + try: + response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True) + response.raise_for_status() + + text_output = "" + for line in response.iter_lines(): + if line: + json_data = json.loads(line) + if 'response' in json_data: + text_output += json_data['response'] + + # 更新历史记录 + history.append((query, text_output)) + redis_client.set(session_id, json.dumps(history)) + + return ChatResponse(response=text_output, history=history) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/start_chat") +async def start_chat(): + session_id = str(uuid.uuid4()) + return {"session_id": session_id} + +if __name__ == '__main__': + # 启动Kafka消费者线程 + start_kafka_consumers() + + # 启动FastAPI服务器 + uvicorn.run(app, host="0.0.0.0", port=6001) \ No 
newline at end of file diff --git a/api_chat/before/producer_chat_1.py b/api_chat/before/producer_chat_1.py new file mode 100644 index 0000000..907680a --- /dev/null +++ b/api_chat/before/producer_chat_1.py @@ -0,0 +1,319 @@ +from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader +from pydantic import BaseModel +from kafka import KafkaProducer +from redis import Redis +import os +import json +import uuid +from datetime import datetime, timezone +from dotenv import load_dotenv +import tempfile +import hashlib +import asyncio + + +# 加载 .env 文件 +load_dotenv() + +app = FastAPI() +v1_chat_app = FastAPI() +app.mount("/v1_chat", v1_chat_app) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') +REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) +REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB')) +REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB')) +REDIS_API_DB = int(os.getenv('REDIS_API_DB')) +REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB')) +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) + +KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC') +KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC') +KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC') + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB) +redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB) +redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, 
password=REDIS_PASSWORD, db=REDIS_CHAT_DB) +redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB) +redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB) +redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB) + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +def get_audio_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + 
model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str): + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + await update_token_usage(api_key, tokens_required, model_name) + + # 发送任务到Kafka + producer.send(kafka_topic, task_data) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return { + "message": f"{model_name.upper()}请求已排队等待处理", + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + } + +class TTSRequest(BaseModel): + text: str + +class ChatRequest(BaseModel): + session_id: str + query: str + model: str = "qwen2.5:3b" + + +# 添加WebSocket连接管理 +class ConnectionManager: + def __init__(self): + self.active_connections = {} + + async def connect(self, websocket: WebSocket, client_id: str): + await websocket.accept() + self.active_connections[client_id] = websocket + + def disconnect(self, client_id: str): + self.active_connections.pop(client_id, None) + + async def send_message(self, message: str, client_id: str): + if client_id in self.active_connections: + await self.active_connections[client_id].send_text(message) + +manager = 
ConnectionManager() + + +@v1_chat_app.websocket("/ws/{client_id}") +async def websocket_endpoint(websocket: WebSocket, client_id: str): + await manager.connect(websocket, client_id) + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + manager.disconnect(client_id) + +# 修改TTS请求处理函数 +@v1_chat_app.post("/tts") +async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + task_data = { + "task_id": task_id, + "text": request.text, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + redis_task_client.set(f"task_status:{task_id}", "queued") + + result = await process_request(api_key_info, "tts", 100, task_data, KAFKA_TTS_TOPIC) + result["task_id"] = task_id + + # 将任务ID存储到Redis,以便后续WebSocket通信使用 + redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key']) + + return result + +# 修改ASR请求处理函数 +@v1_chat_app.post("/asr") +async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + + UPLOAD_DIR = "/obscura/task/audio_upload" + os.makedirs(UPLOAD_DIR, exist_ok=True) + file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav") + + with open(file_path, "wb") as temp_audio: + content = await audio.read() + temp_audio.write(content) + + task_data = { + 'file_path': file_path, + 'task_id': task_id, + 'status': 'queued' + } + + redis_task_client.set(f"task_status:{task_id}", "queued") + + result = await process_request(api_key_info, "asr", 100, task_data, KAFKA_ASR_TOPIC) + result["task_id"] = task_id + + # 将任务ID存储到Redis,以便后续WebSocket通信使用 + redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key']) + + return result + +# 修改聊天请求处理函数 +@v1_chat_app.post("/chat") +async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + task_data = { + "task_id": task_id, + "session_id": request.session_id, + "query": 
request.query, + "model": request.model, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + redis_task_client.set(f"task_status:{task_id}", "queued") + + result = await process_request(api_key_info, "chat", 100, task_data, KAFKA_CHAT_TOPIC) + result["task_id"] = task_id + + # 将任务ID存储到Redis,以便后续WebSocket通信使用 + redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key']) + + return result + +@v1_chat_app.get("/chat_result/{task_id}") +async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + # 从Redis任务数据库获取任务状态 + task_status = redis_task_client.get(f"task_status:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从Redis聊天结果数据库获取聊天结果 + chat_result = redis_chat_client.get(task_id) + if chat_result: + result = json.loads(chat_result) + return { + "status": "completed", + "history": result # 直接返回整个历史记录 + } + return {"status": status} + + return {"status": "not_found"} + +@v1_chat_app.get("/tts_result/{task_id}") +async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + # 从Redis任务数据库获取任务状态 + task_status = redis_task_client.get(f"task_status:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从Redis TTS结果数据库获取音频文件路径 + audio_info = redis_tts_client.get(task_id) + if audio_info: + audio_path = json.loads(audio_info)['path'] + return { + "status": "completed", + "audio_path": audio_path + } + return {"status": status} + + return {"status": "not_found"} + +@v1_chat_app.get("/asr_result/{task_id}") +async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + # 从Redis任务数据库获取任务状态 + task_status = redis_task_client.get(f"task_status:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从Redis ASR结果数据库获取转录结果 + transcription = redis_asr_client.get(task_id) + if transcription: + return { + "status": 
"completed", + "transcription": transcription.decode('utf-8') + } + return {"status": status} + + return {"status": "not_found"} + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8008) \ No newline at end of file diff --git a/api_chat/before/sovits_api.py b/api_chat/before/sovits_api.py new file mode 100644 index 0000000..1718c6c --- /dev/null +++ b/api_chat/before/sovits_api.py @@ -0,0 +1,180 @@ +import os +import soundfile as sf +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel, Field +import uvicorn +import redis +import hashlib +import json +from kafka import KafkaProducer, KafkaConsumer +import threading +import time +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +from dotenv import load_dotenv +import os +import torch + +# 加载 .env 文件 +load_dotenv() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# FastAPI configuration +app = FastAPI() +i18n = I18nAuto() + +# CORS configuration +ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',') + +# Redis configuration +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_DB = int(os.getenv('REDIS_TTS_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# Kafka configuration +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC') +# KAFKA_GROUP_ID = 'tts_group' +KAFKA_CONSUMER_THREADS = 1 + +# TTS configuration +GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH') +SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH') +REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH') +REF_TEXT_PATH = os.getenv('REF_TEXT_PATH') +REF_LANGUAGE = os.getenv('REF_LANGUAGE') +TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE') +OUTPUT_PATH = os.getenv('OUTPUT_PATH') + +# Initialize FastAPI CORS 
+app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize Redis client +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD +) + +# Initialize Kafka producer +kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER) + +class TTSRequest(BaseModel): + text: str = Field(..., alias="text") + +def get_audio_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +# Initialize models at startup +print("Initializing models...") +change_gpt_weights(gpt_path=GPT_MODEL_PATH) +change_sovits_weights(sovits_path=SOVITS_MODEL_PATH) + +# Read reference text +with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file: + ref_text = file.read() + +print("Models initialized successfully.") + +def synthesize(target_text, output_path): + # Synthesize audio + with torch.cuda.device(device): + synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH, + prompt_text=ref_text, + prompt_language=i18n(REF_LANGUAGE), + text=target_text, + text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + audio_hash = get_audio_hash(target_text) + output_wav_path = os.path.join(output_path, f"{audio_hash}.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + return output_wav_path + else: + return None + +@app.post("/tts") +async def synthesize_audio(request: TTSRequest): + try: + print(f"Received TTS request: {request.dict()}") + target_text = request.text + audio_hash = get_audio_hash(target_text) + + # Check Redis cache + cached_audio = redis_client.get(audio_hash) + if cached_audio: + audio_info = json.loads(cached_audio) + return FileResponse(audio_info['path'], media_type="audio/wav") + + # Check file system + file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav") + if 
os.path.exists(file_path): + # Cache the file path in Redis + redis_client.set(audio_hash, json.dumps({"path": file_path})) + return FileResponse(file_path, media_type="audio/wav") + + # Send message to Kafka + kafka_producer.send(KAFKA_TOPIC, json.dumps({ + 'text': target_text, + 'audio_hash': audio_hash + }).encode('utf-8')) + + # Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism) + for _ in range(60): # Wait for up to 30 seconds + if os.path.exists(file_path): + return FileResponse(file_path, media_type="audio/wav") + time.sleep(1) + + # If audio is not generated within the timeout + raise HTTPException(status_code=504, detail="Audio generation timed out") + except Exception as e: + print(f"Error processing TTS request: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/") +async def root(): + return {"message": "TTS API is running"} + +def kafka_consumer_thread(): + consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=KAFKA_BROKER, + # group_id=KAFKA_GROUP_ID, + auto_offset_reset='latest', + value_deserializer=lambda m: json.loads(m.decode('utf-8')) + ) + + for message in consumer: + target_text = message.value['text'] + audio_hash = message.value['audio_hash'] + + output_path = synthesize(target_text, OUTPUT_PATH) + + if output_path: + redis_client.set(audio_hash, json.dumps({"path": output_path})) + print(f"Audio synthesized successfully: {output_path}") + else: + print("Failed to synthesize audio") + +if __name__ == "__main__": + # Start Kafka consumer threads + torch.cuda.set_device(device) + for _ in range(KAFKA_CONSUMER_THREADS): + consumer_thread = threading.Thread(target=kafka_consumer_thread) + consumer_thread.start() + + uvicorn.run(app, host="0.0.0.0", port=6002) \ No newline at end of file diff --git a/api_chat/before/sovits_api_1.py b/api_chat/before/sovits_api_1.py new file mode 100644 index 0000000..bdda30f --- /dev/null +++ b/api_chat/before/sovits_api_1.py 
@@ -0,0 +1,180 @@ +import os +import soundfile as sf +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse +from pydantic import BaseModel, Field +import uvicorn +import redis +import hashlib +import json +from kafka import KafkaProducer, KafkaConsumer +import threading +import time +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +from dotenv import load_dotenv +import os +import torch + +# 加载 .env 文件 +load_dotenv() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# FastAPI configuration +app = FastAPI() +i18n = I18nAuto() + +# CORS configuration +ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',') + +# Redis configuration +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_DB = int(os.getenv('REDIS_TTS_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# Kafka configuration +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC') +# KAFKA_GROUP_ID = 'tts_group' +KAFKA_CONSUMER_THREADS = 1 + +# TTS configuration +GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH') +SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH') +REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH') +REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH') +REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE') +TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE') +OUTPUT_PATH = os.getenv('OUTPUT_PATH') + +# Initialize FastAPI CORS +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Initialize Redis client +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD +) + +# Initialize Kafka producer +kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER) + +class 
TTSRequest(BaseModel): + text: str = Field(..., alias="text") + +def get_audio_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +# Initialize models at startup +print("Initializing models...") +change_gpt_weights(gpt_path=GPT_MODEL_PATH) +change_sovits_weights(sovits_path=SOVITS_MODEL_PATH) + +# Read reference text +with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file: + ref_text = file.read() + +print("Models initialized successfully.") + +def synthesize(target_text, output_path): + # Synthesize audio + with torch.cuda.device(device): + synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH, + prompt_text=ref_text, + prompt_language=i18n(REF_LANGUAGE), + text=target_text, + text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + audio_hash = get_audio_hash(target_text) + output_wav_path = os.path.join(output_path, f"{audio_hash}.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + return output_wav_path + else: + return None + +@app.post("/tts_ko") +async def synthesize_audio(request: TTSRequest): + try: + print(f"Received TTS request: {request.dict()}") + target_text = request.text + audio_hash = get_audio_hash(target_text) + + # Check Redis cache + cached_audio = redis_client.get(audio_hash) + if cached_audio: + audio_info = json.loads(cached_audio) + return FileResponse(audio_info['path'], media_type="audio/wav") + + # Check file system + file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav") + if os.path.exists(file_path): + # Cache the file path in Redis + redis_client.set(audio_hash, json.dumps({"path": file_path})) + return FileResponse(file_path, media_type="audio/wav") + + # Send message to Kafka + kafka_producer.send(KAFKA_TOPIC, json.dumps({ + 'text': target_text, + 'audio_hash': audio_hash + }).encode('utf-8')) + + # Wait for the audio to be generated (you might want to implement a more 
sophisticated waiting mechanism) + for _ in range(60): # Wait for up to 30 seconds + if os.path.exists(file_path): + return FileResponse(file_path, media_type="audio/wav") + time.sleep(1) + + # If audio is not generated within the timeout + raise HTTPException(status_code=504, detail="Audio generation timed out") + except Exception as e: + print(f"Error processing TTS request: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/") +async def root(): + return {"message": "TTS API is running"} + +def kafka_consumer_thread(): + consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=KAFKA_BROKER, + # group_id=KAFKA_GROUP_ID, + auto_offset_reset='latest', + value_deserializer=lambda m: json.loads(m.decode('utf-8')) + ) + + for message in consumer: + target_text = message.value['text'] + audio_hash = message.value['audio_hash'] + + output_path = synthesize(target_text, OUTPUT_PATH) + + if output_path: + redis_client.set(audio_hash, json.dumps({"path": output_path})) + print(f"Audio synthesized successfully: {output_path}") + else: + print("Failed to synthesize audio") + +if __name__ == "__main__": + # Start Kafka consumer threads + torch.cuda.set_device(device) + for _ in range(KAFKA_CONSUMER_THREADS): + consumer_thread = threading.Thread(target=kafka_consumer_thread) + consumer_thread.start() + + uvicorn.run(app, host="0.0.0.0", port=6003) \ No newline at end of file diff --git a/api_chat/before/tts1.py b/api_chat/before/tts1.py new file mode 100644 index 0000000..7bf9e69 --- /dev/null +++ b/api_chat/before/tts1.py @@ -0,0 +1,176 @@ +# 导入所需的库 +import os +import soundfile as sf +import redis +import hashlib +import json +from kafka import KafkaConsumer +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +from dotenv import load_dotenv +import torch + +""" +整体设计说明: +这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务, +使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。 +主要功能包括: 
+1. 初始化配置和模型 +2. 提供语音合成功能 +3. 监听Kafka消息并处理TTS任务 +4. 将合成结果存储到Redis并更新任务状态 +""" + +# 加载环境变量 +load_dotenv() + +# 设置GPU设备(如果可用) +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(f"使用设备: {device}") + +# 从环境变量中读取Redis配置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果 +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态 +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# 从环境变量中读取Kafka配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC') + +# 从环境变量中读取TTS相关配置 +GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH') +SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH') +REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH') +REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH') +REF_LANGUAGE = os.getenv('REF_LANGUAGE') +TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE') +OUTPUT_PATH = os.getenv('OUTPUT_PATH') + +# 初始化Redis客户端 +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_TTS_DB, + password=REDIS_PASSWORD +) + +redis_task_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_TASK_DB, + password=REDIS_PASSWORD +) + +# 初始化国际化工具 +i18n = I18nAuto() + +def get_audio_hash(text): + """ + 生成文本的MD5哈希值,用作音频文件名的一部分 + + 参数: + text (str): 需要生成哈希的文本 + + 返回: + str: 文本的MD5哈希值 + """ + return hashlib.md5(text.encode()).hexdigest() + +# 初始化模型 +print("正在初始化模型...") +change_gpt_weights(gpt_path=GPT_MODEL_PATH) +change_sovits_weights(sovits_path=SOVITS_MODEL_PATH) + +# 读取参考文本 +with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file: + ref_text = file.read() + +print("模型初始化成功。") + +def synthesize(target_text, output_wav_path): + """ + 使用GPT-SoVITS模型合成语音 + + 参数: + target_text (str): 需要合成语音的目标文本 + output_wav_path (str): 输出音频文件的路径 + + 返回: + str: 如果成功,返回输出音频文件的路径;如果失败,返回None + """ + with torch.cuda.device(device): + synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH, + prompt_text=ref_text, + 
prompt_language=i18n(REF_LANGUAGE), + text=target_text, + text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + return output_wav_path + else: + return None + +def kafka_consumer(): + """ + Kafka消费者函数,用于接收和处理TTS任务 + + 该函数会持续监听Kafka的TTS主题,接收任务并进行处理: + 1. 接收任务信息 + 2. 更新任务状态 + 3. 调用synthesize函数合成语音 + 4. 将结果保存到Redis + 5. 更新任务完成状态 + """ + consumer = KafkaConsumer( + KAFKA_TTS_TOPIC, + bootstrap_servers=KAFKA_BROKER, + auto_offset_reset='latest', + value_deserializer=lambda m: json.loads(m.decode('utf-8')) + ) + print(f"TTS消费者已启动") + for message in consumer: + try: + task_id = message.value['task_id'] + target_text = message.value['text'] + text_hash = message.value['text_hash'] + + # 更新任务状态为 "processing" + redis_task_client.set(f"task_status:tts:{task_id}", "processing") + + output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav") + + # 再次检查文件是否存在(以防在此期间被其他进程创建) + if not os.path.exists(output_wav_path): + output_path = synthesize(target_text, output_wav_path) + else: + output_path = output_wav_path + + if output_path: + # 将结果保存在 DB 2 + redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path})) + print(f"音频合成成功: {output_path}") + + # 更新任务状态为 "completed" + redis_task_client.set(f"task_status:tts:{task_id}", "completed") + else: + print("音频合成失败") + + # 更新任务状态为 "failed" + redis_task_client.set(f"task_status:tts:{task_id}", "failed") + except Exception as e: + print(f"处理消息时出错: {str(e)}") + + # 更新任务状态为 "failed" + redis_task_client.set(f"task_status:tts:{task_id}", "failed") + +if __name__ == "__main__": + # 设置CUDA设备 + torch.cuda.set_device(device) + # 启动Kafka消费者 + kafka_consumer() \ No newline at end of file diff --git a/api_chat/before/whisper_api.py b/api_chat/before/whisper_api.py new file mode 100644 index 0000000..cb221b0 --- /dev/null +++ 
b/api_chat/before/whisper_api.py @@ -0,0 +1,186 @@ +from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket +from fastapi.middleware.cors import CORSMiddleware +import whisper +import tempfile +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from starlette.websockets import WebSocketDisconnect +import json +import threading +import os +import uuid +import asyncio +import logging +import redis +from dotenv import load_dotenv + +# 设置要使用的GPU ID +GPU_ID = 1 # 修改这个值来选择要使用的GPU + +# 设置CUDA_VISIBLE_DEVICES环境变量 +os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID) + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = FastAPI() +load_dotenv() + + +# CORS 配置 +ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',') + + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +print("正在加载Whisper模型...") +model = whisper.load_model("large-v3") +print("Whisper模型加载完成。") + +# Kafka配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC') + +# Redis配置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_DB = int(os.getenv('REDIS_ASR_DB')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# 创建Redis客户端 +redis_client = redis.Redis( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB, + password=REDIS_PASSWORD # 添加密码 +) + +# Kafka生产者 +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 存储WebSocket连接的字典 +active_connections = {} + +@app.websocket("/asr/ws/{client_id}") +async def websocket_endpoint(websocket: WebSocket, client_id: str): + await websocket.accept() + active_connections[client_id] = websocket + try: + while True: + try: + # 设置接收超时 + data = await asyncio.wait_for(websocket.receive_text(), timeout=30) + if data == "ping": + await websocket.send_text("pong") + 
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
    """Accept an uploaded audio file and queue it for transcription.

    The upload is hashed and the digest is used as the Redis cache key, so a
    repeated upload of the same bytes is answered from the cache. On a cache
    miss the bytes are written to a temporary ``.wav`` file and a task holding
    ``file_path``, ``client_id`` and ``cache_key`` is sent to ``KAFKA_TOPIC``.

    Args:
        audio: The uploaded audio file.

    Returns:
        dict: Either the cached transcription, or an acknowledgement that the
        task was queued. Both contain the generated ``client_id``.

    Raises:
        HTTPException: 400 when no file, or an empty file, is supplied.
    """
    # Local import: hashlib is not among this module's top-level imports.
    import hashlib

    if not audio:
        raise HTTPException(status_code=400, detail="未提供音频文件")

    content = await audio.read()
    if not content:
        raise HTTPException(status_code=400, detail="未提供音频文件")

    client_id = str(uuid.uuid4())

    # Key the cache on the content. A key containing the fresh client_id
    # is unique per request and therefore can never produce a cache hit.
    cache_key = f"asr:{hashlib.sha256(content).hexdigest()}"

    cached_result = redis_client.get(cache_key)
    if cached_result:
        logger.info(f"缓存命中: {cache_key}")
        return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}

    # delete=False: the file must outlive this request so the consumer can read it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
        temp_audio.write(content)
        temp_audio.flush()

    task = {
        'file_path': temp_audio.name,
        'client_id': client_id,
        'cache_key': cache_key
    }
    try:
        producer.send(KAFKA_TOPIC, value=task)
        producer.flush()
    except Exception:
        # Nobody will ever consume this file, so do not leave it behind.
        os.remove(temp_audio.name)
        raise

    logger.info(f"发送任务到Kafka: {task}")
    return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
def start_consumers(num_consumers=1):
    """Start ``num_consumers`` threads that each run ``kafka_consumer``.

    Each thread receives its zero-based index as the consumer id. The threads
    are started but not joined.

    Args:
        num_consumers (int): Number of consumer threads to start.
    """
    for consumer_index in range(num_consumers):
        threading.Thread(target=kafka_consumer, args=(consumer_index,)).start()
async def process_chat_request(chat_request):
    """Answer one chat request and record the outcome in Redis.

    Steps:
      1. Read ``task_id``, ``session_id``, ``query`` and optional ``model``.
      2. Set ``chat:<task_id>:status`` to "processing".
      3. Load the session history from ``chat:<session_id>`` and build a
         prompt from the system prompt, the history and the new query.
      4. POST the prompt to the generation endpoint and concatenate the
         ``response`` field of every JSON line returned.
      5. Append the exchange to the session history and store the status,
         response and result under the task's keys.

    On any error the task is marked "error" and the message is stored under
    ``chat:<task_id>:error``, provided the task id could be read. A payload
    without a usable ``task_id`` is only logged.

    Args:
        chat_request: Deserialized Kafka message value (expected to be a dict).
    """
    # Bound before the try so the error handler can always refer to it.
    task_id = None
    try:
        task_id = chat_request['task_id']
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        # Mark the task as "processing".
        redis_task_client.set(f"chat:{task_id}:status", "processing")

        # Load the conversation history for this session.
        history = json.loads(redis_client.get(f"chat:{session_id}") or '[]')

        # Build the full prompt including earlier turns.
        full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
        for past_query, past_response in history:
            full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
        full_prompt += f"用户: {query}\n助手:"

        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            "temperature": 0
        }

        response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
        response.raise_for_status()

        text_output = ""
        for line in response.iter_lines():
            if line:
                json_data = json.loads(line)
                if 'response' in json_data:
                    text_output += json_data['response']

        # Persist the updated history for the session.
        history.append((query, text_output))
        redis_client.set(f"chat:{session_id}", json.dumps(history))

        # Mark the task as "completed" and store the response.
        redis_task_client.set(f"chat:{task_id}:status", "completed")
        redis_task_client.set(f"chat:{task_id}:response", text_output)

        # Store the query/response pair as the task result.
        redis_task_client.set(f"chat:{task_id}:result", json.dumps({"query": query, "response": text_output}))

        print(f"处理完成 task_id {task_id}, session_id {session_id}: {text_output}")

    except Exception as e:
        print(f"处理 task {task_id} 时出错: {str(e)}")
        # Without a task id there are no task keys to update.
        if task_id is not None:
            # Mark the task as "error" and keep the message for the caller.
            redis_task_client.set(f"chat:{task_id}:status", "error")
            redis_task_client.set(f"chat:{task_id}:error", str(e))
mp4_to_wav(input_path, output_file) + elif os.path.isdir(input_path): + process_directory(input_path) + else: + print("错误: 输入路径既不是文件也不是目录") + +if __name__ == "__main__": + main() diff --git a/api_chat/ollamas.py b/api_chat/ollamas.py new file mode 100644 index 0000000..4d2db09 --- /dev/null +++ b/api_chat/ollamas.py @@ -0,0 +1,136 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from fastapi.middleware.cors import CORSMiddleware +import httpx +import json +import redis +from typing import List, Dict, Optional +import logging +import ollama +import uuid + +app = FastAPI() + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Redis连接 +redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class GenerateRequest(BaseModel): + model: Optional[str] = "qwen2.5:3b" + prompt: str + +class RawGenerateRequest(BaseModel): + model: Optional[str] = "qwen2.5:3b" + prompt: str + system_prompt: Optional[str] = None + stream: Optional[bool] = False + raw: Optional[bool] = False + format: Optional[str] = None + options: Optional[Dict] = None + +class GenerateResponse(BaseModel): + response: dict + request_id: str + +@app.post("/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest): + logger.info(f"收到请求: {request}") + + request_id = str(uuid.uuid4()) + + try: + response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}]) + full_response = response['message']['content'] + + request_data = { + "model": request.model, + "prompt": request.prompt, + "response": full_response + } + + redis_client.set(f"request:{request_id}", json.dumps(request_data)) + + response_data = { + "response": full_response, + "model": request.model + } + + return 
@app.post("/api/generate")
async def generate_without_history(request: RawGenerateRequest):
    """Run a single generation request without conversation history.

    The request is forwarded to ``ollama.generate``. When ``request.stream``
    is true the client yields partial chunks instead of one response; the
    chunks are consumed here, their ``response`` texts are concatenated, and
    the statistics are taken from the last chunk. Either way a single JSON
    object is returned and also stored in Redis under ``request:<uuid>``.

    Args:
        request: RawGenerateRequest holding the generation parameters.

    Returns:
        dict: The generated text plus the statistics reported by Ollama.

    Raises:
        HTTPException: 500 when the Ollama call or the Redis write fails.
    """
    try:
        result = ollama.generate(
            model=request.model,
            prompt=request.prompt,
            system=request.system_prompt,
            format=request.format,
            options=request.options,
            stream=request.stream
        )

        if request.stream:
            # A streamed result is an iterator of chunks and cannot be
            # subscripted directly; collect it before building the reply.
            chunks = list(result)
            generated_text = "".join(chunk['response'] for chunk in chunks)
            final = chunks[-1] if chunks else {}
        else:
            generated_text = result['response']
            final = result

        response_data = {
            "model": request.model,
            "response": generated_text,
            "done": True,
            "context": final.get('context'),
            "total_duration": final.get('total_duration'),
            "load_duration": final.get('load_duration'),
            "prompt_eval_count": final.get('prompt_eval_count'),
            "prompt_eval_duration": final.get('prompt_eval_duration'),
            "eval_count": final.get('eval_count'),
            "eval_duration": final.get('eval_duration')
        }

        request_id = str(uuid.uuid4())
        redis_client.set(f"request:{request_id}", json.dumps(response_data))

        return response_data

    except Exception as e:
        logger.error(f"发生未预期的错误: {e}")
        logger.exception("详细错误信息:")
        raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
def get_audio_hash(text):
    """Return the MD5 hex digest of ``text`` (encoded with the default UTF-8)."""
    encoded = text.encode()
    return hashlib.md5(encoded).hexdigest()
+REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB')) +REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB')) + +KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC') +KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC') +KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC') + +OUTPUT_PATH= os.getenv('OUTPUT_PATH') + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB) +redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB) +redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB) +redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB) +redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB) +redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB) + +redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB) +redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB) +redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB) +redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB) +redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB) +redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB) +redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB) +redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB) +redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, 
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the API key from the ``Authorization`` header value.

    The header must start with ``"Bearer "`` and the token that follows must
    start with ``"obs-"``.

    Args:
        api_key: Raw header value, or a falsy value when the header is absent.

    Returns:
        str: The token part of the header.

    Raises:
        HTTPException: 401 when the header is missing or malformed.
    """
    has_bearer_prefix = bool(api_key) and api_key.startswith("Bearer ")
    # An empty token can never satisfy the "obs-" check below.
    token = api_key.split(" ")[1] if has_bearer_prefix else ""
    if token.startswith("obs-"):
        return token
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge tokens for a request and enqueue its task on Kafka.

    Args:
        api_key_info: Verified key information; must contain ``api_key``.
        model_name: Service name used for the per-model usage counters.
        tokens_required: Number of tokens this request costs.
        task_data: Payload sent to Kafka.
        kafka_topic: Topic the task is sent to.

    Returns:
        dict: Queue acknowledgement with the tokens charged, the total and
        per-model tokens used so far, and the remaining balance.

    Raises:
        HTTPException: 403 when the remaining balance is insufficient.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_raw, used_raw = redis_api_usage_client.hmget(usage_key, ["total_tokens", "tokens_used"])
    total_tokens = int(total_raw or 0)
    tokens_used = int(used_raw or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Record the token usage.
    await update_token_usage(api_key, tokens_required, model_name)

    # Send the task to Kafka.
    producer.send(kafka_topic, task_data)

    # Read the updated counters straight from the usage hash. Re-running the
    # full key verification here could answer 401 for a task that has already
    # been charged and queued, and costs two extra HGETALL calls.
    model_tokens_field = f"{model_name}_tokens_used"
    new_used_raw, model_used_raw = redis_api_usage_client.hmget(usage_key, ["tokens_used", model_tokens_field])
    new_tokens_used = int(new_used_raw or 0)
    model_tokens_used = int(model_used_raw or 0)

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a text-to-speech task, or reuse an existing audio file.

    The requested voice selects a Redis client from ``voice_to_redis``. If
    that Redis holds ``tts:<md5 of text>`` and the referenced file exists on
    disk, the existing path is returned at once. Otherwise the task status is
    set to "queued" and the task is charged and sent to Kafka through
    ``process_request``.

    Args:
        request: Text to synthesize and the voice to use.
        api_key_info: Verified API key information.

    Returns:
        dict: Either a "completed" answer with ``audio_path``, or the queue
        acknowledgement from ``process_request`` plus ``task_id``.

    Raises:
        HTTPException: 400 when the voice is not a key of ``voice_to_redis``.
    """
    task_id = str(uuid.uuid4())
    text_hash = get_audio_hash(request.text)

    # Validate against the mapping that is actually used for the lookup, so
    # the accepted voices cannot drift from the configured Redis clients.
    # "default" maps to the same client as "girl" in voice_to_redis.
    redis_tts = voice_to_redis.get(request.voice)
    if redis_tts is None:
        raise HTTPException(status_code=400, detail="无效的音色选择")

    # Reuse an existing audio file for identical text, if it is still on disk.
    existing_audio_info = redis_tts.get(f"tts:{text_hash}")
    if existing_audio_info:
        existing_audio_path = json.loads(existing_audio_info)['path']
        if os.path.exists(existing_audio_path):
            return {
                "message": "TTS请求已完成",
                "task_id": task_id,
                "status": "completed",
                "audio_path": existing_audio_path
            }

    # Otherwise create a new task.
    task_data = {
        "task_id": task_id,
        "text": request.text,
        "text_hash": text_hash,
        "voice": request.voice,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Record the initial task status.
    redis_task_client.set(f"task_status:tts:{task_id}", "queued")

    result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Report the status of a chat task and, once completed, its result.

    Args:
        task_id: Identifier returned when the chat task was queued.
        api_key_info: Verified API key information.

    Returns:
        dict: ``{"status": "not_found"}`` when no status is stored;
        ``{"status": "completed", "result": ...}`` when the task completed
        and a result is stored; otherwise ``{"status": <stored status>}``.
    """
    raw_status = redis_task_client.get(f"chat:{task_id}:status")
    if not raw_status:
        return {"status": "not_found"}

    state = raw_status.decode('utf-8')
    if state != "completed":
        return {"status": state}

    # Completed: attach the stored result when there is one.
    stored_result = redis_task_client.get(f"chat:{task_id}:result")
    if not stored_result:
        return {"status": state}

    return {
        "status": "completed",
        "result": json.loads(stored_result)
    }
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Report the status of an ASR task and, once completed, its transcription.

    Args:
        task_id: Identifier returned when the ASR task was queued.
        api_key_info: Verified API key information.

    Returns:
        dict: ``{"status": "not_found"}`` when no status is stored; otherwise
        the stored status, plus ``transcription`` when the task completed and
        a transcription is stored under ``asr:<task_id>``.
    """
    raw_status = redis_task_client.get(f"task_status:asr:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    reply = {"status": raw_status.decode('utf-8')}
    if reply["status"] == "completed":
        # The transcription lives in the ASR result DB, not the task DB.
        raw_text = redis_asr_client.get(f"asr:{task_id}")
        if raw_text:
            reply["transcription"] = raw_text.decode('utf-8')
    return reply
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
    """List the voice names accepted by the TTS endpoint.

    The names are the keys of ``voice_to_redis`` in definition order, so the
    advertised voices always match the configured Redis clients.

    Args:
        api_key_info: Verified API key information.

    Returns:
        dict: ``{"available_voices": [...]}``.
    """
    return {"available_voices": list(voice_to_redis)}
\ No newline at end of file diff --git a/api_chat/sample/hejiong.wav b/api_chat/sample/hejiong.wav new file mode 100644 index 0000000..0b62700 Binary files /dev/null and b/api_chat/sample/hejiong.wav differ diff --git a/api_chat/sample/leijun.txt b/api_chat/sample/leijun.txt new file mode 100644 index 0000000..487aa60 --- /dev/null +++ b/api_chat/sample/leijun.txt @@ -0,0 +1,3 @@ +我们在短短的半年之间的时间里面 +就组成了超过一千人的团队 +在过去三年多的时间里面 diff --git a/api_chat/sample/leijun.wav b/api_chat/sample/leijun.wav new file mode 100644 index 0000000..88f4345 Binary files /dev/null and b/api_chat/sample/leijun.wav differ diff --git a/api_chat/sample/lidan.txt b/api_chat/sample/lidan.txt new file mode 100644 index 0000000..6a2a595 --- /dev/null +++ b/api_chat/sample/lidan.txt @@ -0,0 +1,4 @@ +他去那个商场 +两口子去逛 买电视 +大家知道现在的智能电视那个遥控器都有一个语音搜索功能 +年轻人不怎么用其实 \ No newline at end of file diff --git a/api_chat/sample/lidan.wav b/api_chat/sample/lidan.wav new file mode 100644 index 0000000..8314add Binary files /dev/null and b/api_chat/sample/lidan.wav differ diff --git a/api_chat/sample/liuzhenyun.txt b/api_chat/sample/liuzhenyun.txt new file mode 100644 index 0000000..9e12676 --- /dev/null +++ b/api_chat/sample/liuzhenyun.txt @@ -0,0 +1,5 @@ +在我们村里 +最有见识的人呢 +是我舅 +他是个赶马车的 +他不但去过县城 \ No newline at end of file diff --git a/api_chat/sample/liuzhenyun.wav b/api_chat/sample/liuzhenyun.wav new file mode 100644 index 0000000..e3ab8a7 Binary files /dev/null and b/api_chat/sample/liuzhenyun.wav differ diff --git a/api_chat/sample/luoxiang.txt b/api_chat/sample/luoxiang.txt new file mode 100644 index 0000000..746bb64 --- /dev/null +++ b/api_chat/sample/luoxiang.txt @@ -0,0 +1,3 @@ +所以无论是中外 +最大限度的反抗标准 +在很长一段时间都是一种主流立场 \ No newline at end of file diff --git a/api_chat/sample/luoxiang.wav b/api_chat/sample/luoxiang.wav new file mode 100644 index 0000000..9dee240 Binary files /dev/null and b/api_chat/sample/luoxiang.wav differ diff --git a/api_chat/sample/mahuateng.txt b/api_chat/sample/mahuateng.txt new file mode 
100644 index 0000000..8e217de --- /dev/null +++ b/api_chat/sample/mahuateng.txt @@ -0,0 +1,2 @@ +那么今天呢 政治主持人介绍是我们第四次的互联网家峰会 +那么这次的规模是世界以来规模最大的 diff --git a/api_chat/sample/mahuateng.wav b/api_chat/sample/mahuateng.wav new file mode 100644 index 0000000..88540bd Binary files /dev/null and b/api_chat/sample/mahuateng.wav differ diff --git a/api_chat/sample/man.txt b/api_chat/sample/man.txt new file mode 100644 index 0000000..141918d --- /dev/null +++ b/api_chat/sample/man.txt @@ -0,0 +1,2 @@ +今年以來 我國全力推動鄉村產業全鏈條升級 +鄉村產業振興呈現良好勢頭 diff --git a/api_chat/sample/man.wav b/api_chat/sample/man.wav new file mode 100644 index 0000000..977ee4d Binary files /dev/null and b/api_chat/sample/man.wav differ diff --git a/api_chat/sample/woman.txt b/api_chat/sample/woman.txt new file mode 100644 index 0000000..eefa986 --- /dev/null +++ b/api_chat/sample/woman.txt @@ -0,0 +1,3 @@ +政法機關堅持黨對政法工作的絕對領導 +推動政法體制和工作機制 +實現歷史性變革 \ No newline at end of file diff --git a/api_chat/sample/woman.wav b/api_chat/sample/woman.wav new file mode 100644 index 0000000..5af03e8 Binary files /dev/null and b/api_chat/sample/woman.wav differ diff --git a/api_chat/sample/xuzhiyuan.txt b/api_chat/sample/xuzhiyuan.txt new file mode 100644 index 0000000..0281a36 --- /dev/null +++ b/api_chat/sample/xuzhiyuan.txt @@ -0,0 +1,5 @@ +我们约了他去做他的采访 +他已经答应了 +然后结果去那天 +他那时候已经得了白血病了 +他说真是不巧不好意思 \ No newline at end of file diff --git a/api_chat/sample/xuzhiyuan.wav b/api_chat/sample/xuzhiyuan.wav new file mode 100644 index 0000000..f13effc Binary files /dev/null and b/api_chat/sample/xuzhiyuan.wav differ diff --git a/api_chat/sample/yuhua.txt b/api_chat/sample/yuhua.txt new file mode 100644 index 0000000..aec9d49 --- /dev/null +++ b/api_chat/sample/yuhua.txt @@ -0,0 +1,5 @@ +很小很小的房间 +来人的话呢 +如果一个人说要外出去上厕所 +因为厕所是公共厕所 +所有人都得起来走到外面去 \ No newline at end of file diff --git a/api_chat/sample/yuhua.wav b/api_chat/sample/yuhua.wav new file mode 100644 index 0000000..59b7a21 Binary files /dev/null and b/api_chat/sample/yuhua.wav differ 
diff --git a/api_chat/tools/.gitkeep b/api_chat/tools/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/api_chat/tts.py b/api_chat/tts.py new file mode 100644 index 0000000..a2ad58e --- /dev/null +++ b/api_chat/tts.py @@ -0,0 +1,315 @@ +import os +import soundfile as sf +import redis +import hashlib +import json +import traceback +from kafka import KafkaConsumer +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav +from dotenv import load_dotenv +import torch + +# 加载 .env 文件 +load_dotenv() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print(f"使用设备: {device}") + +# Redis 配置 +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3 +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') + +# Kafka 配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC') + +# TTS 配置 +GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH') +SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH') +REF_LANGUAGE = os.getenv('REF_LANGUAGE') +TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE') +OUTPUT_PATH = os.getenv('OUTPUT_PATH') + +# Redis 配置 +REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB')) +REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB')) +REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB')) +REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB')) +REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB')) +REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB')) +REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB')) +REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB')) +REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB')) +REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB')) +REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB')) +REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB')) +REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB')) + +# 初始化 Redis 客户端 +redis_tts_girl = 
redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB) +redis_tts_woman = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB) +redis_tts_man = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB) +redis_tts_leijun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB) +redis_tts_dufu = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB) +redis_tts_hejiong = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB) +redis_tts_mahuateng = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB) +redis_tts_lidan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB) +redis_tts_yuhua = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB) +redis_tts_liuzhenyun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB) +redis_tts_dabing = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB) +redis_tts_luoxiang = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB) +redis_tts_xuzhiyuan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB) + +redis_task_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_TASK_DB, password=REDIS_PASSWORD) + +# 创建音色到对应 Redis 客户端的映射 +voice_to_redis = { + "default": redis_tts_girl, + "girl": redis_tts_girl, + "woman": redis_tts_woman, + "man": redis_tts_man, + "leijun": redis_tts_leijun, + "dufu": redis_tts_dufu, + "hejiong": redis_tts_hejiong, + "mahuateng": redis_tts_mahuateng, + "lidan": redis_tts_lidan, + "yuhua": redis_tts_yuhua, + "liuzhenyun": redis_tts_liuzhenyun, + "dabing": redis_tts_dabing, + "luoxiang": redis_tts_luoxiang, + "xuzhiyuan": 
redis_tts_xuzhiyuan +} + +i18n = I18nAuto() + +# Voice configurations +GIRL_REF_AUDIO = os.getenv('GIRL_REF_AUDIO') +GIRL_REF_TEXT = os.getenv('GIRL_REF_TEXT') +WOMAN_REF_AUDIO = os.getenv('WOMAN_REF_AUDIO') +WOMAN_REF_TEXT = os.getenv('WOMAN_REF_TEXT') +MAN_REF_AUDIO = os.getenv('MAN_REF_AUDIO') +MAN_REF_TEXT = os.getenv('MAN_REF_TEXT') +LEIJUN_REF_AUDIO = os.getenv('LEIJUN_REF_AUDIO') +LEIJUN_REF_TEXT = os.getenv('LEIJUN_REF_TEXT') +DUFU_REF_AUDIO = os.getenv('DUFU_REF_AUDIO') +DUFU_REF_TEXT = os.getenv('DUFU_REF_TEXT') +HEJIONG_REF_AUDIO = os.getenv('HEJIONG_REF_AUDIO') +HEJIONG_REF_TEXT = os.getenv('HEJIONG_REF_TEXT') +MAHUATENG_REF_AUDIO = os.getenv('MAHUATENG_REF_AUDIO') +MAHUATENG_REF_TEXT = os.getenv('MAHUATENG_REF_TEXT') +LIDAN_REF_AUDIO = os.getenv('LIDAN_REF_AUDIO') +LIDAN_REF_TEXT = os.getenv('LIDAN_REF_TEXT') +YUHUA_REF_AUDIO = os.getenv('YUHUA_REF_AUDIO') +YUHUA_REF_TEXT = os.getenv('YUHUA_REF_TEXT') +LIUZHENYUN_REF_AUDIO = os.getenv('LIUZHENYUN_REF_AUDIO') +LIUZHENYUN_REF_TEXT = os.getenv('LIUZHENYUN_REF_TEXT') +DABING_REF_AUDIO = os.getenv('DABING_REF_AUDIO') +DABING_REF_TEXT = os.getenv('DABING_REF_TEXT') +LUOXIANG_REF_AUDIO = os.getenv('LUOXIANG_REF_AUDIO') +LUOXIANG_REF_TEXT = os.getenv('LUOXIANG_REF_TEXT') +XUZHIYUAN_REF_AUDIO = os.getenv('XUZHIYUAN_REF_AUDIO') +XUZHIYUAN_REF_TEXT = os.getenv('XUZHIYUAN_REF_TEXT') + +VOICE_CONFIGS = { + "girl": { + "ref_audio": GIRL_REF_AUDIO, + "ref_text": GIRL_REF_TEXT, + "ref_language": REF_LANGUAGE + }, + "woman": { + "ref_audio": WOMAN_REF_AUDIO, + "ref_text": WOMAN_REF_TEXT, + "ref_language": REF_LANGUAGE + }, + "man": { + "ref_audio": MAN_REF_AUDIO, + "ref_text": MAN_REF_TEXT, + "ref_language": REF_LANGUAGE + }, + "leijun": { + "ref_audio": LEIJUN_REF_AUDIO, + "ref_text": LEIJUN_REF_TEXT, + "ref_language": REF_LANGUAGE + }, + "dufu": { + "ref_audio": DUFU_REF_AUDIO, + "ref_text": DUFU_REF_TEXT, + "ref_language": REF_LANGUAGE + }, + "hejiong": { + "ref_audio": HEJIONG_REF_AUDIO, + "ref_text": 
def read_ref_text(voice_type):
    """Return the reference transcript for ``voice_type``, or ``""`` if unavailable.

    The transcript path is ``VOICE_CONFIGS[voice_type]["ref_text"]``. That value
    is read from the environment with ``os.getenv`` and is therefore ``None``
    when the variable is not set.

    Args:
        voice_type: key into ``VOICE_CONFIGS``.

    Returns:
        The file content, or an empty string when the path is unset, the file
        does not exist, or it cannot be read (a message is printed).

    Raises:
        KeyError: if ``voice_type`` is not a key of ``VOICE_CONFIGS``.
    """
    ref_text_path = VOICE_CONFIGS[voice_type]["ref_text"]
    if not ref_text_path:
        # Unset env var: os.path.exists(None) used to raise TypeError here.
        print(f"警告:{voice_type} 的参考文本文件 '{ref_text_path}' 不存在。")
        return ""
    try:
        # Open directly instead of exists()+open(): no window in which the
        # file can disappear between the check and the read.
        with open(ref_text_path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        print(f"警告:{voice_type} 的参考文本文件 '{ref_text_path}' 不存在。")
    except OSError as e:
        print(f"错误:无法读取 {voice_type} 的参考文本文件 '{ref_text_path}'。{str(e)}")
    return ""


def synthesize(target_text, output_wav_path, voice):
    """Synthesize ``target_text`` with ``voice`` and write the audio to disk.

    Args:
        target_text: text to speak.
        output_wav_path: destination path handed to ``sf.write``.
        voice: key into ``VOICE_CONFIGS``.

    Returns:
        ``output_wav_path`` on success, ``None`` when ``get_tts_wav`` yields
        no audio.

    Raises:
        KeyError: if ``voice`` is not a key of ``VOICE_CONFIGS``.
        OSError: if the reference transcript cannot be opened.
    """
    from contextlib import nullcontext

    voice_config = VOICE_CONFIGS[voice]
    ref_audio_path = voice_config["ref_audio"]

    # Read strictly rather than through read_ref_text(): a missing transcript
    # must fail the task instead of silently synthesizing with an empty prompt.
    with open(voice_config["ref_text"], 'r', encoding='utf-8') as file:
        ref_text = file.read()

    # torch.cuda.device() rejects non-CUDA devices, so the module-level CPU
    # fallback needs a no-op context instead.
    device_ctx = torch.cuda.device(device) if device.type == "cuda" else nullcontext()
    with device_ctx:
        synthesis_result = get_tts_wav(
            ref_wav_path=ref_audio_path,
            prompt_text=ref_text,
            prompt_language=i18n(voice_config["ref_language"]),
            text=target_text,
            text_language=i18n(TARGET_LANGUAGE),
            top_p=1,
            temperature=1
        )
        # Exhaust the result inside the device context; only the last
        # (sampling_rate, audio) pair is kept below.
        result_list = list(synthesis_result)

    if not result_list:
        return None

    last_sampling_rate, last_audio_data = result_list[-1]
    sf.write(output_wav_path, last_audio_data, last_sampling_rate)
    return output_wav_path
def transcribe_audio(model, audio_path):
    """Transcribe one audio file with a loaded Whisper model.

    Args:
        model: object exposing ``transcribe(path)`` that returns a mapping
            with a ``"text"`` entry.
        audio_path: path of the audio file.

    Returns:
        The transcribed text, or ``None`` if transcription raised.
    """
    try:
        result = model.transcribe(audio_path)
        return result["text"]
    except Exception as e:
        # Best effort: one unreadable file must not abort a whole batch.
        print(f"转录失败 {audio_path}: {str(e)}")
        return None


def _transcribe_to_txt(model, wav_path):
    """Write the transcript of ``wav_path`` next to it; return the .txt path or None."""
    output_file = os.path.splitext(wav_path)[0] + ".txt"
    transcription = transcribe_audio(model, wav_path)
    # Only None means failure. An empty transcript (silent audio) is a valid
    # result and must still overwrite any stale .txt from an earlier run.
    if transcription is None:
        return None
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(transcription)
    return output_file


def process_directory(directory, model):
    """Transcribe every ``.wav`` file (case-insensitive) directly inside ``directory``.

    Args:
        directory: directory containing the WAV files (not searched recursively).
        model: loaded Whisper model, see ``transcribe_audio``.
    """
    # Sorted so progress output and processing order are reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith('.wav'):
            continue
        input_file = os.path.join(directory, filename)

        print(f"正在处理: {input_file}")
        output_file = _transcribe_to_txt(model, input_file)

        if output_file is not None:
            print(f"转录完成: {output_file}")
        else:
            print(f"转录失败: {input_file}")


def main():
    """Command-line entry point: transcribe a single WAV file or a directory of them."""
    parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
    parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"],
                        help="Whisper模型大小")
    args = parser.parse_args()

    print(f"正在加载Whisper模型 ({args.model})...")
    model = whisper.load_model(args.model)
    print("模型加载完成")

    if os.path.isfile(args.input_path):
        if not args.input_path.lower().endswith('.wav'):
            print("错误: 输入文件不是WAV格式")
            return
        output_file = _transcribe_to_txt(model, args.input_path)
        if output_file is not None:
            print(f"转录完成: {output_file}")
    elif os.path.isdir(args.input_path):
        process_directory(args.input_path, model)
    else:
        print("错误: 输入路径既不是文件也不是目录")
class ImageSequenceProcessor:
    """Describe a short image sequence with a chat model and mine the answer for keywords."""

    def __init__(self, model_dir):
        """Load the model (bfloat16, moved to CUDA) and its tokenizer from ``model_dir``."""
        self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
            attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        self.max_size = 512  # maximum size setting (not read anywhere in this class)

    def extract_time_from_key(self, key_name):
        """Parse the trailing ``..._YYYYMMDD_HHMMSS.ext`` part of ``key_name``.

        Raises:
            IndexError: if ``key_name`` has fewer than two ``_``-separated parts.
            ValueError: if the last two parts are not a valid date and time.
        """
        parts = key_name.split('_')
        time_str = parts[-2] + '_' + parts[-1].split('.')[0]
        return datetime.strptime(time_str, "%Y%m%d_%H%M%S")

    def process_image_sequence(self, image_data, key_names):
        """Ask the model to describe the frames and return the answer plus metadata.

        Args:
            image_data: iterable of 3-channel ``numpy.ndarray`` images.
            key_names: keys matching ``image_data`` one to one; each must be
                parseable by ``extract_time_from_key``.

        Returns:
            dict with ``original_answer``, ``extracted_info``, ``num_frames``,
            ``time_range`` and ``sequence_period_seconds``.

        Raises:
            ValueError: if no input produced a usable frame.
        """
        frames = []
        image_times = []
        for i, (img, key_name) in enumerate(zip(image_data, key_names)):
            try:
                if not isinstance(img, np.ndarray):
                    print(f"Unexpected data type for image {i}: {type(img)}")
                    continue
                if img.shape[2] != 3:
                    print(f"Unexpected number of channels for image {i}: {img.shape[2]}")
                    continue
                # NOTE(review): Image.fromarray() treats the array as RGB, while
                # arrays decoded by cv2 are BGR. Confirm the caller converts.
                frame = Image.fromarray(img)
                # Parse the timestamp before appending anything, so a bad key
                # cannot leave `frames` one entry longer than `image_times`.
                captured_at = self.extract_time_from_key(key_name)
                frames.append(frame)
                image_times.append(captured_at)
                print(f"Successfully processed frame {i}")
            except Exception as e:
                print(f"Error processing frame {i}: {str(e)}")
                continue

        if not frames:
            raise ValueError("No valid frames were processed")

        start_time = min(image_times)
        end_time = max(image_times)
        time_range = {
            'start': start_time.strftime('%Y-%m-%d %H:%M'),
            'end': end_time.strftime('%Y-%m-%d %H:%M')
        }

        # Average spacing between consecutive frames, in seconds.
        total_duration = (end_time - start_time).total_seconds()
        num_images = len(image_times)
        if num_images > 1:
            sequence_period_seconds = total_duration / (num_images - 1)
        else:
            sequence_period_seconds = 0

        # State the real frame count: frames skipped above would otherwise
        # make a hard-coded "3" wrong.
        question = (
            f"Analyze these {len(frames)} images as if they were frames from a video. "
            "Describe the scene in detail in Chinese, including the setting, number of people, "
            "their actions, and any changes or movements observed across the frames."
        )
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }

        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            **params
        )

        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "num_frames": len(frames),
            "time_range": time_range,
            "sequence_period_seconds": sequence_period_seconds
        }

    @staticmethod
    def extract_info(answer):
        """Extract coarse scene facts from the model's Chinese answer by keyword matching.

        Returns:
            dict with ``environment`` (str or None), ``num_people`` (int, or
            None when no count can be determined), and the lists ``actions``,
            ``interactions``, ``objects`` and ``furniture``.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        # First listed environment that occurs in the answer wins.
        environments = ["办公室", "室内", "室外", "会议室", "办公"]
        lowered = answer.lower()
        for env in environments:
            if env in lowered:
                info["environment"] = env
                break

        # Ordered from explicit counts down to mere mentions of people.
        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|两|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|两个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        count_words = {
            '一': 1, '一个': 1, '两': 2, '两个': 2, '二': 2, '三': 3, '四': 4,
            '五': 5, '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
        }
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if not match:
                continue
            token = match.group(1)
            if token.isdigit():
                info["num_people"] = int(token)
            else:
                # A match without a numeral ("几个", "员工", ...) only proves
                # that people are mentioned. Leave the count unknown (None)
                # instead of reporting a misleading 0.
                info["num_people"] = count_words.get(token)
            break

        # dict.fromkeys() drops the repeated "转身" candidate while keeping
        # order, so an action is reported at most once.
        actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
        info["actions"] = [action for action in dict.fromkeys(actions) if action in answer]

        interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
        info["interactions"] = [word for word in interactions if word in answer]

        objects = ["水瓶", "办公用品", "文件", "电脑"]
        furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
        info["objects"] = [obj for obj in objects if obj in answer]
        info["furniture"] = [item for item in furniture if item in answer]

        return info


class ImageSequenceAnalysisSystem:
    """Wrap ``ImageSequenceProcessor`` with logging, timing and error containment."""

    def __init__(self, model_dir):
        self.image_processor = ImageSequenceProcessor(model_dir)

    def process_image_sequence(self, image_data, cache_keys):
        """Run the processor on a frame sequence.

        Returns:
            The processor's result dict, or ``None`` if processing raised
            (the error and traceback are printed).
        """
        print(f"Attempting to process sequence of {len(image_data)} images")
        start_time = time.time()
        try:
            print("Processing new image sequence...")
            for i, (img, cache_key) in enumerate(zip(image_data, cache_keys)):
                print(f"Image {i} type: {type(img)}")
                if isinstance(img, np.ndarray):
                    print(f"Image {i} shape: {img.shape}, dtype: {img.dtype}")
                else:
                    print(f"Unexpected data type for image {i}")
                print(f"Image {i} cache_key: {cache_key}")

            result = self.image_processor.process_image_sequence(image_data, cache_keys)

            processing_time = time.time() - start_time

            print(f"Processed image sequence for time range: {result['time_range']}")
            # The value is in seconds (see its key); the old label said minutes.
            print(f"Average time between frames: {result['sequence_period_seconds']:.2f} seconds")
            print(f"Processing time: {processing_time:.2f} seconds")

            return result

        except Exception as e:
            processing_time = time.time() - start_time

            print(f"Error processing image sequence: {str(e)}")
            print(f"Processing time (including error): {processing_time:.2f} seconds")
            import traceback
            traceback.print_exc()
            return None
def process_and_produce(self, message_value):
    """Buffer the image named by one Kafka message; analyse every batch of three.

    The image is fetched from the db0 pool under ``cache_key`` as base64 text
    and decoded with OpenCV. When three images are buffered they are analysed
    and, on success, the result is stored in the db1 pool under the first
    image's ``cache_key``.

    Args:
        message_value: decoded message dict; ``cache_key``, ``etag``, ``size``
            and ``object`` are read with ``.get``.
    """
    cache_key = message_value.get('cache_key')
    etag = message_value.get('etag')
    size = message_value.get('size')
    object_key = message_value.get('object')

    try:
        with redis.Redis(connection_pool=self.redis_pool_db0) as redis_client_db0:
            img_str = redis_client_db0.get(cache_key)
        if img_str is None:
            print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
            return

        img_data = base64.b64decode(img_str)
        nparr = np.frombuffer(img_data, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode() signals failure with None; buffering it would make
            # the whole batch of three fail later.
            print(f"Error: Unable to decode image data for cache_key: {cache_key}")
            return
        # OpenCV decodes to BGR, but the frames are later wrapped with
        # PIL.Image.fromarray(), which assumes RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        self.image_buffer.append(img)
        self.image_info_buffer.append({
            'key_name': cache_key,  # cache_key doubles as the key name
            'etag': etag,
            'size': size,
            'object_key': object_key
        })
        self.cache_key_buffer.append(cache_key)

        if len(self.image_buffer) < 3:
            return

        try:
            result = self.analysis_system.process_image_sequence(self.image_buffer, self.cache_key_buffer)

            if result:
                result_data = {
                    'results': result,
                    'image_sequence': self.image_info_buffer
                }

                with self.lock:
                    with redis.Redis(connection_pool=self.redis_pool_db1) as redis_client_db1:
                        # The first image's cache_key identifies the sequence.
                        sequence_key = f"{self.image_info_buffer[0]['key_name']}"
                        redis_client_db1.set(sequence_key, json.dumps(result_data))
                        print(f"Consumer {self.group_id} processed and saved results to Redis db1 with sequence_key {sequence_key}")
                        print(f"Processed image sequence: {[info['key_name'] for info in self.image_info_buffer]}")
            else:
                print(f"Sequence analysis failed; dropping buffered frames: {self.cache_key_buffer}")
        finally:
            # Always start a fresh batch. If the buffers survived a failed
            # analysis, their length would pass 3 and no batch would ever be
            # analysed again.
            self.image_buffer = []
            self.image_info_buffer = []
            self.cache_key_buffer = []

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        import traceback
        traceback.print_exc()
class YOLOv8nPoseProcessor:
    """Run a YOLO model on images and serialise its detections as JSON."""

    def __init__(self, model_path):
        """Load the YOLO weights found at ``model_path``."""
        self.model = YOLO(model_path)

    def process_image(self, img):
        """Run inference on ``img`` and return the model's results unchanged."""
        return self.model(img)

    def format_results(self, results):
        """Serialise detections to a JSON string.

        Args:
            results: iterable of result objects exposing ``boxes.xywh``,
                ``boxes.cls``, ``boxes.conf`` and optionally ``keypoints.xy``,
                each convertible with ``.tolist()``.

        Returns:
            JSON text of a list with one dict per detection (``box``,
            ``keypoints``, ``class``, ``class_name``, ``confidence``). When
            nothing was detected the list holds one placeholder entry with
            ``message: 'No object detected'``.
        """
        result_data = []
        for result in results:
            boxes = result.boxes.xywh.tolist()
            # The attribute can exist and still be None, so hasattr() alone
            # does not protect the `.xy` access.
            keypoint_data = getattr(result, 'keypoints', None)
            keypoints = keypoint_data.xy.tolist() if keypoint_data is not None else None
            classes = result.boxes.cls.tolist()
            confs = result.boxes.conf.tolist()

            for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
                result_data.append({
                    'box': box,
                    'keypoints': keypoints[i] if keypoints else None,
                    'class': int(cls),
                    'class_name': self.model.names[int(cls)],
                    'confidence': float(conf)
                })

        if not result_data:
            # Keep the payload shape stable for consumers when nothing is found.
            result_data.append({
                'box': None,
                'keypoints': None,
                'class': None,
                'class_name': None,
                'confidence': None,
                'message': 'No object detected'
            })

        return json.dumps(result_data)
def process_and_produce(self, message):
    """Run pose detection on the image named by ``message`` and store the result.

    The image is read from Redis db0 under ``cache_key`` as base64 text. The
    detections, together with the message's ``object``, ``etag`` and ``size``,
    are written to Redis db1 under the same key as JSON.

    Args:
        message: decoded Kafka message dict; all fields are read with ``.get``.
    """
    cache_key = message.get('cache_key')
    object_key = message.get('object')
    etag = message.get('etag')
    size = message.get('size')

    print(f"Consumer {self.group_id} processing image with cache_key: {cache_key}")

    try:
        # The Redis read sits inside the try so a connection error is reported
        # like every other failure of this handler.
        img_str = self.redis_client_db0.get(cache_key)
        if img_str is None:
            print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
            return

        # base64 text -> raw bytes -> OpenCV image.
        img_data = base64.b64decode(img_str)
        nparr = np.frombuffer(img_data, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if img is None:
            print(f"Error: Unable to decode image data for cache_key: {cache_key}")
            return

        results = self.yolo_processor.process_image(img)
        formatted_results = self.yolo_processor.format_results(results)

        result_data = {
            "pose_results": json.loads(formatted_results),
            'cache_key': cache_key,
            "object": object_key,
            "etag": etag,
            "size": size
        }

        try:
            # json.dumps() never returns an empty string, so the former
            # "serialized data is empty" branch could not run and is gone.
            serialized_data = json.dumps(result_data)
            self.redis_client_db1.set(cache_key, serialized_data)
            print(f"Consumer {self.group_id} processed and saved results to Redis db1 with cache_key {cache_key}")
        except TypeError as e:
            print(f"Error serializing data: {e}")
            print(f"Problematic data: {result_data}")
        except Exception as e:
            print(f"Unexpected error when saving to Redis: {e}")
    except Exception as e:
        print(f"Error processing image: {e}")
def parse_key_name(key_name):
    """Split ``"bucket/camera_YYYYMMDD_HHMMSS.ext"`` into its three parts.

    Returns:
        ``(bucket, camera_name, timestamp)``. ``bucket`` is the text before
        the first ``/``. ``camera_name`` is the first ``_``-separated field of
        the remainder with its last extension removed. ``timestamp`` is the
        next two fields joined by ``_`` (shorter when fewer fields exist).
    """
    bucket, _, object_path = key_name.partition('/')
    stem = object_path.rsplit('.', 1)[0]
    fields = stem.split('_')
    return bucket, fields[0], '_'.join(fields[1:3])
'{timestamp}': {str(e)}") + return False + +def get_image_from_minio_and_cache(key_name): + bucket, object_name = key_name.split('/', 1) + cache_key = f"{key_name}" + + try: + response = minio_client.get_object(bucket, object_name) + image_data = response.read() + print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}") + + image = Image.open(io.BytesIO(image_data)) + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode() + + redis_client.setex(cache_key, 86400, img_str) + print(f"Successfully cached image in Redis: {cache_key}") + + return cache_key + except Exception as e: + print(f"Error in get_image_from_minio_and_cache: {str(e)}") + print("Traceback:") + traceback.print_exc() + return None + +def process_message(message): + + value_data = json.loads(message.value.decode('utf-8')) + key_name = value_data['Key'] + print(f"Processing key: {key_name}") + + # 从 key_name 中提取 bucket + parts = key_name.split('/', 1) + if len(parts) != 2: + print(f"Error: Invalid key format. 
Expected 'bucket/object', got '{key_name}'") + return + + bucket, object_name = parts + + # 解析 camera_name 和 timestamp + camera_name, timestamp = parse_object_name(object_name) + + cache_key = get_image_from_minio_and_cache(key_name) + + if cache_key is None: + print(f"Failed to process image: {key_name}") + return + # 提取 etag 和 size 信息 + object_info = value_data['Records'][0]['s3']['object'] + etag = object_info.get('eTag', '') + size = object_info.get('size', 0) + + + message_data = { + 'bucket': bucket, + 'camera_name': camera_name, + 'timestamp': timestamp, + 'object': object_name, + 'cache_key': cache_key, + 'etag': etag, + 'size': size + } + + # 发送到 pose-input 主题 + producer.send(topic_all_frames, value=message_data) + print(f"Sent message to {topic_all_frames}: {key_name}") + + # 发送到 cpm-input 主题 + producer.send(topic_ten_seconds, value=message_data) + print(f"Sent message to {topic_ten_seconds}: {key_name}") + + # #只有在满足特定条件时才发送到 cpm-input 主题 + # if should_send_to_ten_seconds_topic(timestamp): + # producer.send(TOPIC_TEN_SECONDS, value=message_data) + # print(f"Sent message to {TOPIC_TEN_SECONDS}: {key_name}") + + producer.flush() + print(f"Successfully processed and sent messages for: {key_name}") + + +def parse_object_name(object_name): + # 假设对象名格式为 "cameraX_YYYYMMDD_HHMMSS.jpg" + parts = object_name.split('_') + if len(parts) != 3: + raise ValueError(f"Invalid object name format: {object_name}") + + camera_name = parts[0] + timestamp = f"{parts[1]}_{parts[2].split('.')[0]}" + return camera_name, timestamp + +def get_image_from_minio_and_cache(key_name): + print(f"Received key_name: {key_name}") + + try: + bucket, object_name = key_name.split('/', 1) + except ValueError as e: + print(f"Error splitting key_name '{key_name}': {str(e)}") + return None + + cache_key = f"{key_name}" + + try: + response = minio_client.get_object(bucket, object_name) + image_data = response.read() + print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}") + + 
image = Image.open(io.BytesIO(image_data)) + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode() + + redis_client.setex(cache_key, 86400, img_str) + print(f"Successfully cached image in Redis: {cache_key}") + + return cache_key + except Exception as e: + print(f"Error in get_image_from_minio_and_cache: {str(e)}") + print("Traceback:") + traceback.print_exc() + return None + +def main(): + print("Starting image processing service...") + for message in consumer: + process_message(message) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/api_history/function/sync.py b/api_history/function/sync.py new file mode 100644 index 0000000..3148628 --- /dev/null +++ b/api_history/function/sync.py @@ -0,0 +1,91 @@ +import redis +import pymongo +import time +import json +import yaml + +def sync_data(redis_client, mongo_db, redis_db, mongo_collection_name): + mongo_collection = mongo_db[mongo_collection_name] + all_keys = redis_client.keys('*') + synced_count = 0 + + for key in all_keys: + try: + poem_data = redis_client.get(key) + + if poem_data: + poem_dict = json.loads(poem_data) + + # 检查MongoDB中是否已存在相同的文档 + existing_doc = mongo_collection.find_one({'_id': key.decode('utf-8')}) + + if not existing_doc or existing_doc != poem_dict: + mongo_collection.update_one( + {'_id': key.decode('utf-8')}, + {'$set': poem_dict}, + upsert=True + ) + print(f"Synced data with key: {key} from Redis DB {redis_db} to MongoDB collection '{mongo_collection_name}'") + synced_count += 1 + except json.JSONDecodeError: + print(f"Error decoding JSON for key: {key} in Redis DB {redis_db}") + except Exception as e: + print(f"Error syncing data for key {key} in Redis DB {redis_db}: {str(e)}") + + return synced_count + +def main(config): + # Redis配置 + redis_host = config['redis']['host'] + redis_port = config['redis']['port'] + redis_password = config['redis']['password'] + + # MongoDB配置 + mongo_uri = 
config['mongodb']['uri'] + mongo_db_name = config['mongodb']['db_name'] + + # 连接到MongoDB + mongo_client = pymongo.MongoClient(mongo_uri) + mongo_db = mongo_client[mongo_db_name] + + # 固定的Redis数据库和MongoDB集合映射 + db_collection_map = { + 0: 'pose-result-db0', + 1: 'pose-result-db1', + 2: 'cpm-result-db2' + } + + print("Selected databases and collections for syncing:") + for db, collection in db_collection_map.items(): + print(f" Redis DB {db} -> MongoDB collection '{collection}'") + + while True: + print("Starting sync...") + total_synced = 0 + for db, collection in db_collection_map.items(): + print(f"Syncing Redis DB {db} to MongoDB collection '{collection}'...") + try: + redis_client = redis.Redis(host=redis_host, port=redis_port, db=db, password=redis_password) + synced_count = sync_data(redis_client, mongo_db, db, collection) + total_synced += synced_count + except redis.exceptions.AuthenticationError: + print(f"Error: Authentication failed for Redis DB {db}. Skipping...") + except redis.exceptions.ConnectionError: + print(f"Error: Unable to connect to Redis DB {db}. Skipping...") + except Exception as e: + print(f"Error occurred while syncing Redis DB {db}: {str(e)}. Skipping...") + + if total_synced > 0: + print(f"Sync completed. {total_synced} documents synced. Waiting for next update...") + else: + print("No new data to sync. 
Waiting for next update...") + + time.sleep(300) # 等待5分钟后再次同步 + +if __name__ == "__main__": + # 加载配置文件 + with open('worker_sys/function/config.yaml', 'r') as file: + config = yaml.safe_load(file) + + # 运行主程序 + main(config) \ No newline at end of file diff --git a/api_history/local/local_cpm.py b/api_history/local/local_cpm.py new file mode 100644 index 0000000..7fe3fad --- /dev/null +++ b/api_history/local/local_cpm.py @@ -0,0 +1,299 @@ +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +import json +import re +from pymongo import MongoClient +import time +from bson import ObjectId +import os +import glob +from datetime import datetime, timedelta + +# 数据库连接模块 +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.results_collection = self.db[results_collection_name] + + def save_result(self, result): + # 如果 result 中没有 filename,使用时间戳作为替代 + # filename = result.get('filename', f"unknown_{result['timestamp']}") + filename = result.get('filename') + # 检查是否已存在相同 filename 的结果 + existing_result = self.results_collection.find_one({'filename': filename}) + if existing_result: + print(f"Video with filename {filename} has already been processed. 
Skipping.") + return + + # 确保 result 中有 filename + result['filename'] = filename + + # 将 ObjectId 转换为字符串 + if 'video_id' in result and isinstance(result['video_id'], ObjectId): + result['video_id'] = str(result['video_id']) + + self.results_collection.insert_one(result) + + def is_sequence_processed(self, filename): + return self.results_collection.find_one({'filename': filename}) is not None + + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + +# 视频处理模块 +class ImageSequenceProcessor: + def __init__(self, model_dir): + self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, + attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda() + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.max_size = 512 # 设置最大尺寸 + + def compress_image(self, image): + # 保持纵横比的情况下调整图片大小 + image.thumbnail((self.max_size, self.max_size)) + + # 如果图像已经是JPEG格式,直接返回调整大小后的图像 + if image.format == 'JPEG': + return image + + # 对于非JPEG格式,进行压缩 + buffer = io.BytesIO() + if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info): + # 保持透明度 + image.save(buffer, format="PNG", optimize=True) + else: + # 转换为JPEG并压缩 + image.convert('RGB').save(buffer, format="JPEG", quality=85, optimize=True) + + buffer.seek(0) + return Image.open(buffer) + + def process_image_sequence(self, image_paths): + frames = [self.compress_image(Image.open(img_path)) for img_path in image_paths] + question = "Analyze these 10 images as if they were frames from a video. Describe the scene in detail in Chinese, including the setting, number of people, their actions, and any changes or movements observed across the frames." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + } + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + # 环境提取 + environments = ["办公室", "室内", "室外", "会议室", "办公"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + # 改进的人数提取 + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + # 动作和互动提取 + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + # 物体和家具提取 + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + 
info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 主处理类 +class ImageSequenceAnalysisSystem: + def __init__(self, mongo_uri, db_name, model_dir, results_collection_name): + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.image_processor = ImageSequenceProcessor(model_dir) + self.last_processed_time = datetime.now() - timedelta(hours=1) + + def get_all_images(self, image_folders): + image_files = [] + for folder in image_folders: + image_files.extend(glob.glob(os.path.join(folder, '*.jpg'))) + image_files.sort() + return image_files + + + def process_image_sequence(self, image_paths): + print(f"Attempting to process sequence: {[os.path.basename(img) for img in image_paths]}") + start_time = time.time() + try: + # 使用第一张图片的文件名作为序列的标识符 + filename = os.path.basename(image_paths[0]) + + if self.db_handler.is_sequence_processed(filename): + print(f"Skipping already processed image sequence: {filename}") + return False + + print("Processing new image sequence...") + result = self.image_processor.process_image_sequence(image_paths) + + # timestamp = datetime.now() + # result['timestamp'] = timestamp.strftime("%Y%m%d_%H%M%S") + result['image_paths'] = image_paths + result['filename'] = filename + + # 计算图片序列的周期 + image_times = [self.get_file_time(img) for img in image_paths] + if len(image_times) >= 2: + time_diff = (image_times[-1] - image_times[0]).total_seconds() + period_minutes = time_diff / (len(image_times) - 1) / 60 + result['sequence_period_minutes'] = round(period_minutes, 2) + + # 添加时间段信息 + result['time_range'] = { + 'start': image_times[0].strftime("%Y-%m-%d %H:%M"), + 'end': image_times[-1].strftime("%Y-%m-%d %H:%M") + } + + save_result = self.db_handler.save_result(result) + print(f"Result saved to: {self.db_handler.results_collection.name}") + print(f"Result filename: {filename}") + return save_result + + except Exception as e: + end_time = 
time.time() + processing_time = end_time - start_time + + print(f"Error processing image sequence: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + return False + # @staticmethod + # def extract_time_from_filename(filename): + # # 假设文件名格式为 "YYYYMMDDHHMMSS.jpg" + # time_str = filename.split('.')[0] + # return datetime.strptime(time_str, "%Y%m%d%H%M") + + @staticmethod + def get_file_time(file_path): + # 获取文件的修改时间 + mod_time = os.path.getmtime(file_path) + return datetime.fromtimestamp(mod_time) + + def process_all_unprocessed_images(self, image_folders): + print(f"Searching for unprocessed images in: {image_folders}") + all_images = self.get_all_images(image_folders) + print(f"Found {len(all_images)} images in total") + selected_images = all_images[::10] + # Group images into sequences of 3 + image_sequences = [selected_images[i:i+10] for i in range(0, len(selected_images), 10)] + # image_sequences = [all_images[i:i+10] for i in range(0, len(all_images), 10)] + + processed_sequences = 0 + for sequence in image_sequences: + if len(sequence) == 10: + self.process_image_sequence(sequence) + processed_sequences += 1 + else: + print(f"Warning: Incomplete sequence. Found {len(sequence)} images.") + + if processed_sequences == 0: + print("All current photos have been processed. Waiting for new photos...") + else: + print(f"Processed {processed_sequences} sequences.") + + def run(self, root_folders): + print(f"Starting the system with root folder: {', '.join(root_folders)}") + + while True: + current_time = datetime.now() + time_since_last_process = (current_time - self.last_processed_time).total_seconds() + + if time_since_last_process >= 3600: # 1小时 = 3600秒 + self.process_all_unprocessed_images(root_folders) + self.last_processed_time = current_time + + # 计算下次检查的等待时间 + wait_time = max(0, 3600 - (datetime.now() - self.last_processed_time).total_seconds()) + print(f"Waiting for new photos... 
Next check in {wait_time:.0f} seconds.") + + time.sleep(60) # 每分钟检查一次是否需要处理 +# 使用示例 +if __name__ == "__main__": + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "cpm" + + model_dir = "worker_sys/OpenBMB/MiniCPM-V-2_6" + + root_folders = [ + "/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" , + "/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics" + ] # 修改为 cam1 文件夹的路径 + + system = ImageSequenceAnalysisSystem(mongo_uri, db_name, model_dir, results_collection_name) + system.run(root_folders) \ No newline at end of file diff --git a/api_history/local/local_pose.py b/api_history/local/local_pose.py new file mode 100644 index 0000000..6fefc75 --- /dev/null +++ b/api_history/local/local_pose.py @@ -0,0 +1,194 @@ +import torch +from PIL import Image +import json +from pymongo import MongoClient +import os +import glob +from datetime import datetime, timedelta +from ultralytics import YOLO +from bson import ObjectId +import time + +# 数据库连接模块 +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.results_collection = self.db[results_collection_name] + + def save_result(self, result): + filename = result.get('filename', f"unknown_{result['timestamp']}") + + existing_result = self.results_collection.find_one({'filename': filename}) + if existing_result: + print(f"Image with filename {filename} has already been processed. 
Skipping.") + return False + + result['filename'] = filename + + if 'image_id' in result and isinstance(result['image_id'], ObjectId): + result['image_id'] = str(result['image_id']) + + self.results_collection.insert_one(result) + return True + + def is_image_processed(self, filename): + return self.results_collection.find_one({'filename': filename}) is not None + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + +# YOLOv8nPoseProcessor 类 +class YOLOv8nPoseProcessor: + def __init__(self, model_path): + self.model = YOLO(model_path) + + def process_image(self, img): + results = self.model(img) + return results + + def format_results(self, results): + result_data = [] + for result in results: + boxes = result.boxes.xywh.tolist() + keypoints = result.keypoints.xy.tolist() if hasattr(result, 'keypoints') else None + classes = result.boxes.cls.tolist() + confs = result.boxes.conf.tolist() + + for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)): + result_data.append({ + 'box': box, + 'keypoints': keypoints[i] if keypoints else None, + 'class': int(cls), + 'class_name': self.model.names[int(cls)], + 'confidence': float(conf) + }) + + if not result_data: + result_data.append({ + 'box': None, + 'keypoints': None, + 'class': None, + 'class_name': None, + 'confidence': None, + 'message': 'No object detected' + }) + + return json.dumps(result_data) + +# 主处理类 +class ImageAnalysisSystem: + def __init__(self, mongo_uri, db_name, model_path, results_collection_name): + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.image_processor = YOLOv8nPoseProcessor(model_path) + self.last_processed_time = datetime.now() - timedelta(hours=1) + + def get_all_images(self, image_folders): + image_files = [] + for folder in image_folders: + image_files.extend(glob.glob(os.path.join(folder, '*.jpg'))) + image_files.sort() + return image_files + @staticmethod + 
def get_file_time(file_path): + # 获取文件的修改时间 + mod_time = os.path.getmtime(file_path) + return datetime.fromtimestamp(mod_time) + def process_image(self, image_path): + print(f"Attempting to process image: {os.path.basename(image_path)}") + try: + # json_folder = os.path.join("/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2", 'json') + # json_filename = f"{os.path.basename(image_path).split('.')[0]}.json" + # json_path = os.path.join(json_folder, json_filename) + + # if os.path.exists(json_path): + # print(f"Skipping already processed image: {json_path}") + # return + + filename = os.path.basename(image_path) + + if self.db_handler.is_image_processed(filename): + print(f"Skipping already processed image: {filename}") + return False + + print("Processing new image...") + + image = Image.open(image_path) + results = self.image_processor.process_image(image) + formatted_results = self.image_processor.format_results(results) + + # timestamp = datetime.now() + file_timestamp = self.get_file_time(image_path) + result = { + 'timestamp': file_timestamp.strftime("%Y%m%d_%H%M%S"), + 'image_path': image_path, + 'filename': os.path.basename(image_path), + 'results': json.loads(formatted_results) + } + + # os.makedirs(json_folder, exist_ok=True) + # with open(json_path, 'w', encoding='utf-8') as f: + # json.dump(result, f, ensure_ascii=False, indent=4, cls=JSONEncoder) + + # self.db_handler.save_result(result) + # print(f"Processed image at: {timestamp}") + # print(f"JSON saved to: {json_path}") + # print(f"result saved to: {results_collection_name}") + if self.db_handler.save_result(result): + print(f"Result saved to: {self.db_handler.results_collection.name}") + return True + else: + print(f"Image {filename} was already in the database. 
Skipping.") + return False + + except Exception as e: + print(f"Error processing image: {str(e)}") + import traceback + traceback.print_exc() + + def process_all_unprocessed_images(self, image_folders): + print(f"Searching for unprocessed images in: {image_folders}") + all_images = self.get_all_images(image_folders) + print(f"Found {len(all_images)} images in total") + + processed_count = 0 + for image_path in all_images: + if self.process_image(image_path): + processed_count += 1 + + return processed_count + + def run(self, root_folder): + print(f"Starting the system with root folder: {', '.join(root_folder)}") + # image_folder = os.path.join(root_folder) + + while True: + print("Checking for unprocessed images...") + processed_count = self.process_all_unprocessed_images(root_folder) + + if processed_count > 0: + print(f"Finished processing {processed_count} images.") + else: + print("No new images to process. Waiting for new images...") + + # 等待一段时间后再次检查新图片 + time.sleep(60) # 每分钟检查一次是否有新图片 + +# 使用示例 +if __name__ == "__main__": + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "pose" + + model_path = "worker_sys/function/yolov8x-pose.pt" # 请确保这个路径指向你的YOLO-Pose模型文件 + + root_folder = [ + "/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" , + "/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics" + ] # 修改为 cam1 文件夹的路径 + + system = ImageAnalysisSystem(mongo_uri, db_name, model_path, results_collection_name) + system.run(root_folder) \ No newline at end of file diff --git a/api_history/mini_douyin_time.py b/api_history/mini_douyin_time.py new file mode 100644 index 0000000..54af528 --- /dev/null +++ b/api_history/mini_douyin_time.py @@ -0,0 +1,255 @@ +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from decord import VideoReader, cpu +import json +import re +from pymongo import MongoClient +import io +from minio import 
Minio +import time +from bson import ObjectId +import concurrent.futures +import os + + +# Minio连接模块 +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + data = response.read() + # print(f"Read {len(data)} bytes from Minio for {object_name}") + return data + +# 数据库连接模块 +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.minio_files_collection = self.db['minio_files'] + self.results_collection = self.db[results_collection_name] + + def get_unprocessed_videos(self): + processed_etags = set(self.results_collection.distinct('etag')) + return self.minio_files_collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'}, + 'etag': {'$nin': list(processed_etags)} + }) + def save_result(self, result): + # 检查是否已存在相同 etag 的结果 + existing_result = self.results_collection.find_one({'etag': result['etag']}) + if existing_result: + print(f"Video with etag {result['etag']} has already been processed. 
Skipping.") + return + + # 将 ObjectId 转换为字符串 + if 'video_id' in result and isinstance(result['video_id'], ObjectId): + result['video_id'] = str(result['video_id']) + + self.results_collection.insert_one(result) + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + + +# 视频处理模块 +class VideoProcessor: + def __init__(self, model_dir): + self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, + attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda() + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.MAX_NUM_FRAMES = 64 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_video(self, video_data, object_name): + if not video_data: + raise ValueError(f"Empty video data for {object_name}") + print(f"Processing video: {object_name}, data size: {len(video_data)} bytes") + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + } + + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + # 环境提取 + environments = ["办公室", "室内", "室外", "会议室", "办公"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + # 改进的人数提取 + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + # 动作和互动提取 + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + # 物体和家具提取 + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + 
info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 主处理类 +class VideoAnalysisSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.video_processor = VideoProcessor(model_dir) + + def process_video(self, video_doc): + start_time = time.time() + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + result = self.video_processor.process_video(video_data, video_doc['object_name']) + + result['etag'] = video_doc['etag'] + result['bucket_name'] = video_doc['bucket_name'] + result['object_name'] = video_doc['object_name'] + + self.db_handler.save_result(result) + + end_time = time.time() + processing_time = end_time - start_time + + print(f"Processed video: {video_doc['object_name']}") + print(f"Processing time: {processing_time:.2f} seconds") + except Exception as e: + end_time = time.time() + processing_time = end_time - start_time + + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. Waiting for 60 seconds before checking again...") + time.sleep(60) + continue + + for video_doc in unprocessed_videos: + self.process_video(video_doc) + + print("Finished processing current batch of videos. 
Waiting for new videos...") + time.sleep(30) +# 使用示例 +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "douyin_results" + + model_dir = "MiniCPM-V-2_6" + + system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name) + system.run() \ No newline at end of file diff --git a/api_history/mini_up.py b/api_history/mini_up.py new file mode 100644 index 0000000..30de80c --- /dev/null +++ b/api_history/mini_up.py @@ -0,0 +1,248 @@ +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from decord import VideoReader, cpu +import json +import re +from pymongo import MongoClient +import io +from minio import Minio +import time +from bson import ObjectId + + +# Minio连接模块 +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + data = response.read() + print(f"Read {len(data)} bytes from Minio for {object_name}") + return data + +# 数据库连接模块 +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.minio_files_collection = self.db['minio_files'] + self.results_collection = self.db[results_collection_name] + + def get_unprocessed_videos(self): + # 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频 + processed_etags = set(self.results_collection.distinct('etag')) + return self.minio_files_collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': 
'videoupload/'}, + 'etag': {'$nin': list(processed_etags)} + }) + + def save_result(self, result): + # 检查是否已存在相同 etag 的结果 + existing_result = self.results_collection.find_one({'etag': result['etag']}) + if existing_result: + print(f"Video with etag {result['etag']} has already been processed. Skipping.") + return + + # 将 ObjectId 转换为字符串 + if 'video_id' in result and isinstance(result['video_id'], ObjectId): + result['video_id'] = str(result['video_id']) + + self.results_collection.insert_one(result) + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + + +# 视频处理模块 +class VideoProcessor: + def __init__(self, model_dir): + self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, + attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda() + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.MAX_NUM_FRAMES = 64 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_video(self, video_data, object_name): + if not video_data: + raise ValueError(f"Empty video data for {object_name}") + print(f"Processing video: {object_name}, data size: {len(video_data)} bytes") + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + } + + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + # 环境提取 + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + # 改进的人数提取 + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + # 动作和互动提取 + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + # 物体和家具提取 + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + 
info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 主处理类 +class VideoAnalysisSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.video_processor = VideoProcessor(model_dir) + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. Waiting for 60 seconds before checking again...") + time.sleep(1) + continue + + for video_doc in unprocessed_videos: + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + result = self.video_processor.process_video(video_data, video_doc['object_name']) + + # 添加额外信息到结果中 + result['etag'] = video_doc['etag'] + # result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串 + result['bucket_name'] = video_doc['bucket_name'] + result['object_name'] = video_doc['object_name'] + + # 保存结果到 MongoDB + self.db_handler.save_result(result) + + print(f"Processed video: {video_doc['object_name']}") + # print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder)) + except Exception as e: + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + import traceback + traceback.print_exc() # 打印完整的错误堆栈 + + print("Finished processing current batch of videos. 
Waiting for new videos...") + time.sleep(30) + +# 使用示例 +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "douyin_results" + + model_dir = "OpenBMB/MiniCPM-V-2_6" + + system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name) + system.run() \ No newline at end of file diff --git a/api_history/minicpmv2.6.py b/api_history/minicpmv2.6.py new file mode 100644 index 0000000..8998110 --- /dev/null +++ b/api_history/minicpmv2.6.py @@ -0,0 +1,264 @@ +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from decord import VideoReader, cpu +import json +import re +from datetime import datetime, timedelta +from pymongo import MongoClient +import io +from minio import Minio +import time +import os +from bson import ObjectId + + +# Minio连接模块 +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + return response.read() + +# 数据库连接模块 +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.minio_files_collection = self.db['minio_files'] + self.results_collection = self.db[results_collection_name] + + def get_unprocessed_videos(self): + # 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频 + processed_etags = set(self.results_collection.distinct('etag')) + return self.minio_files_collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': 
'/douyin/'}, + 'etag': {'$nin': list(processed_etags)} + }) + + def save_result(self, result): + # 检查是否已存在相同 etag 的结果 + existing_result = self.results_collection.find_one({'etag': result['etag']}) + if existing_result: + print(f"Video with etag {result['etag']} has already been processed. Skipping.") + return + + # 将 ObjectId 转换为字符串 + if 'video_id' in result and isinstance(result['video_id'], ObjectId): + result['video_id'] = str(result['video_id']) + + self.results_collection.insert_one(result) + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + + +# 视频处理模块 +class VideoProcessor: + def __init__(self, model_dir): + self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, + attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda() + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.MAX_NUM_FRAMES = 64 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_video(self, video_data, object_name): + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + start_time, end_time = self.extract_time_from_filename(object_name) + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), + "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S") + } + + @staticmethod + def extract_time_from_filename(object_name): + # 从 object_name 中提取文件名 + filename = os.path.basename(object_name) + + # 从文件名中提取日期时间部分 + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + # 环境提取 + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + # 改进的人数提取 + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + 
num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + # 动作和互动提取 + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + # 物体和家具提取 + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 主处理类 +class VideoAnalysisSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.video_processor = VideoProcessor(model_dir) + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. 
Waiting for 60 seconds before checking again...") + time.sleep(1) + continue + + for video_doc in unprocessed_videos: + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + result = self.video_processor.process_video(video_data, video_doc['object_name']) + + # 添加额外信息到结果中 + result['etag'] = video_doc['etag'] + # result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串 + result['bucket_name'] = video_doc['bucket_name'] + result['object_name'] = video_doc['object_name'] + + # 保存结果到 MongoDB + self.db_handler.save_result(result) + + print(f"Processed video: {video_doc['object_name']}") + # print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder)) + except Exception as e: + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + import traceback + traceback.print_exc() # 打印完整的错误堆栈 + + print("Finished processing current batch of videos. Waiting for new videos...") + time.sleep(30) + +# 使用示例 +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "douyin_results" + + model_dir = "OpenBMB/MiniCPM-V-2_6" + + system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name) + system.run() \ No newline at end of file diff --git a/api_history/sound.py b/api_history/sound.py new file mode 100644 index 0000000..113590b --- /dev/null +++ b/api_history/sound.py @@ -0,0 +1,121 @@ +import io +import os +import tempfile +import time +from bson import ObjectId +from minio import Minio +from pymongo import MongoClient +import whisper + +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + 
access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + data = response.read() + print(f"Read {len(data)} bytes from Minio for {object_name}") + return data + +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.collection = self.db[collection_name] + + def get_unprocessed_videos(self): + return self.collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'}, + 'whisper_transcription': {'$exists': False} + }) + + def update_transcription(self, video_id, transcription): + self.collection.update_one( + {'_id': video_id}, + {'$set': {'whisper_transcription': transcription}} + ) + +class WhisperProcessor: + def __init__(self, model_name, model_path=None): + if model_path: + self.model = whisper.load_model(model_name, download_root=model_path) + else: + self.model = whisper.load_model(model_name) + + def transcribe_audio(self, video_data): + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video: + temp_video.write(video_data) + temp_video_path = temp_video.name + + try: + result = self.model.transcribe(temp_video_path) + return result["text"] + finally: + os.unlink(temp_video_path) + +class WhisperTranscriptionSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name, whisper_model_name, whisper_model_path=None): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name) + self.whisper_processor = WhisperProcessor(whisper_model_name, whisper_model_path) + + def process_video(self, video_doc): + start_time = time.time() + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], 
video_doc['object_name']) + transcription = self.whisper_processor.transcribe_audio(video_data) + + self.db_handler.update_transcription(video_doc['_id'], transcription) + + end_time = time.time() + processing_time = end_time - start_time + + print(f"Processed video: {video_doc['object_name']}") + print(f"Processing time: {processing_time:.2f} seconds") + except Exception as e: + end_time = time.time() + processing_time = end_time - start_time + + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. Waiting for 60 seconds before checking again...") + time.sleep(60) + continue + + for video_doc in unprocessed_videos: + self.process_video(video_doc) + + print("Finished processing current batch of videos. 
Waiting for new videos...") + time.sleep(30) + +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + collection_name = "douyin_results" + + whisper_model_name = "large-v3" # 指定模型名称 + whisper_model_path = "whisper" # 指定模型存放路径 + + system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name, + whisper_model_name, whisper_model_path) + system.run() \ No newline at end of file diff --git a/api_history/sound_result copy.py b/api_history/sound_result copy.py new file mode 100644 index 0000000..a1666fa --- /dev/null +++ b/api_history/sound_result copy.py @@ -0,0 +1,114 @@ +import io +import os +import tempfile +import time +from bson import ObjectId +from minio import Minio +from pymongo import MongoClient +import whisper + +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + data = response.read() + print(f"Read {len(data)} bytes from Minio for {object_name}") + return data + +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.collection = self.db[collection_name] + + def get_unprocessed_videos(self): + return self.collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': 'douyin/'}, + 'whisper_transcription': {'$exists': False} + }) + + def update_transcription(self, video_id, transcription): + self.collection.update_one( + {'_id': video_id}, + {'$set': {'whisper_transcription': transcription}} + ) + +class 
WhisperProcessor: + def __init__(self, model_name="large-v3"): + self.model = whisper.load_model(model_name) + + def transcribe_audio(self, video_data): + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video: + temp_video.write(video_data) + temp_video_path = temp_video.name + + try: + result = self.model.transcribe(temp_video_path) + return result["text"] + finally: + os.unlink(temp_video_path) + +class WhisperTranscriptionSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name) + self.whisper_processor = WhisperProcessor() + + def process_video(self, video_doc): + start_time = time.time() + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + transcription = self.whisper_processor.transcribe_audio(video_data) + + self.db_handler.update_transcription(video_doc['_id'], transcription) + + end_time = time.time() + processing_time = end_time - start_time + + print(f"Processed video: {video_doc['object_name']}") + print(f"Processing time: {processing_time:.2f} seconds") + except Exception as e: + end_time = time.time() + processing_time = end_time - start_time + + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. 
Waiting for 60 seconds before checking again...") + time.sleep(60) + continue + + print(f"Found {len(unprocessed_videos)} videos to process.") + for video_doc in unprocessed_videos: + self.process_video(video_doc) + + print("Finished processing current batch of videos. Checking for more...") + +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + collection_name = "douyin_results" + + system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name) + system.run() \ No newline at end of file diff --git a/api_history/sound_result.py b/api_history/sound_result.py new file mode 100644 index 0000000..a1666fa --- /dev/null +++ b/api_history/sound_result.py @@ -0,0 +1,114 @@ +import io +import os +import tempfile +import time +from bson import ObjectId +from minio import Minio +from pymongo import MongoClient +import whisper + +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=True + ) + + def get_video_data(self, bucket, object_name): + response = self.client.get_object(bucket, object_name) + data = response.read() + print(f"Read {len(data)} bytes from Minio for {object_name}") + return data + +class DatabaseHandler: + def __init__(self, mongo_uri, database_name, collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.collection = self.db[collection_name] + + def get_unprocessed_videos(self): + return self.collection.find({ + 'bucket_name': 'raw', + 'object_name': {'$regex': 'douyin/'}, + 'whisper_transcription': {'$exists': False} + }) + + def update_transcription(self, video_id, transcription): + 
self.collection.update_one( + {'_id': video_id}, + {'$set': {'whisper_transcription': transcription}} + ) + +class WhisperProcessor: + def __init__(self, model_name="large-v3"): + self.model = whisper.load_model(model_name) + + def transcribe_audio(self, video_data): + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video: + temp_video.write(video_data) + temp_video_path = temp_video.name + + try: + result = self.model.transcribe(temp_video_path) + return result["text"] + finally: + os.unlink(temp_video_path) + +class WhisperTranscriptionSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name) + self.whisper_processor = WhisperProcessor() + + def process_video(self, video_doc): + start_time = time.time() + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + transcription = self.whisper_processor.transcribe_audio(video_data) + + self.db_handler.update_transcription(video_doc['_id'], transcription) + + end_time = time.time() + processing_time = end_time - start_time + + print(f"Processed video: {video_doc['object_name']}") + print(f"Processing time: {processing_time:.2f} seconds") + except Exception as e: + end_time = time.time() + processing_time = end_time - start_time + + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + + def run(self): + while True: + unprocessed_videos = list(self.db_handler.get_unprocessed_videos()) + + if not unprocessed_videos: + print("No new videos to process. 
Waiting for 60 seconds before checking again...") + time.sleep(60) + continue + + print(f"Found {len(unprocessed_videos)} videos to process.") + for video_doc in unprocessed_videos: + self.process_video(video_doc) + + print("Finished processing current batch of videos. Checking for more...") + +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + collection_name = "douyin_results" + + system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, collection_name) + system.run() \ No newline at end of file diff --git a/api_history/video_s3.py b/api_history/video_s3.py new file mode 100644 index 0000000..ec63d48 --- /dev/null +++ b/api_history/video_s3.py @@ -0,0 +1,256 @@ +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from decord import VideoReader, cpu +import json +import re +from pymongo import MongoClient +import io +from minio import Minio +import time +from bson import ObjectId +import concurrent.futures +import os + +class MinioHandler: + def __init__(self, endpoint, access_key, secret_key, secure=True): + self.client = Minio( + endpoint, + access_key=access_key, + secret_key=secret_key, + secure=secure + ) + + def list_objects(self, bucket_name, prefix): + objects = self.client.list_objects(bucket_name, prefix=prefix, recursive=True) + return [obj for obj in objects if obj.object_name.lower().endswith(('.mp4', '.avi', '.mov', '.flv'))] + + def get_video_data(self, bucket_name, object_name): + try: + response = self.client.get_object(bucket_name, object_name) + return response.read() + except Exception as e: + print(f"Error retrieving video data for {object_name}: {str(e)}") + return None + +class DatabaseHandler: + def __init__(self, 
mongo_uri, database_name, results_collection_name): + self.client = MongoClient(mongo_uri) + self.db = self.client[database_name] + self.results_collection = self.db[results_collection_name] + + def get_unprocessed_videos(self, minio_handler, bucket_name='raw', prefix='videoupload/'): + all_objects = minio_handler.list_objects(bucket_name, prefix) + processed_etags = set(self.results_collection.distinct('etag')) + + unprocessed_videos = [ + { + 'bucket_name': bucket_name, + 'object_name': obj.object_name, + 'etag': obj.etag, + 'size': obj.size, + 'last_modified': obj.last_modified + } + for obj in all_objects if obj.etag not in processed_etags + ] + + return unprocessed_videos + + def save_result(self, result): + existing_result = self.results_collection.find_one({'etag': result['etag']}) + if existing_result: + print(f"Video with etag {result['etag']} has already been processed. Skipping.") + return + + if 'video_id' in result and isinstance(result['video_id'], ObjectId): + result['video_id'] = str(result['video_id']) + + self.results_collection.insert_one(result) + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, ObjectId): + return str(o) + return super().default(o) + +class VideoProcessor: + def __init__(self, model_dir): + self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, + attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda() + self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + self.MAX_NUM_FRAMES = 12 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = 
vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_video(self, video_data, object_name): + if not video_data: + raise ValueError(f"Empty video data for {object_name}") + print(f"Processing video: {object_name}, data size: {len(video_data)} bytes") + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." + msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + } + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + environments = ["办公室", "室内", "室外", "会议室", "办公"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = 
num_word_to_digit.get(match.group(1), 0) + break + + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for action in actions: + if action in answer: + info["actions"].append(action) + + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + for obj in objects: + if obj in answer: + info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +class VideoAnalysisSystem: + def __init__(self, minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name): + self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key) + self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name) + self.video_processor = VideoProcessor(model_dir) + + def process_video(self, video_doc): + start_time = time.time() + try: + video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name']) + result = self.video_processor.process_video(video_data, video_doc['object_name']) + + result['etag'] = video_doc['etag'] + result['bucket_name'] = video_doc['bucket_name'] + result['object_name'] = video_doc['object_name'] + + self.db_handler.save_result(result) + + end_time = time.time() + processing_time = end_time - start_time + + print(f"Processed video: {video_doc['object_name']}") + print(f"Processing time: {processing_time:.2f} seconds") + except Exception as e: + end_time = time.time() + processing_time = end_time - start_time + + print(f"Error processing video {video_doc['object_name']}: {str(e)}") + print(f"Processing time (including error): {processing_time:.2f} seconds") + import traceback + traceback.print_exc() + + def run(self): + while True: + 
unprocessed_videos = self.db_handler.get_unprocessed_videos(self.minio_handler) + + if not unprocessed_videos: + print("No new videos to process. Waiting for 5 seconds before checking again...") + time.sleep(1) + continue + + for video_doc in unprocessed_videos: + self.process_video(video_doc) + + print("Finished processing current batch of videos. Waiting for new videos...") + time.sleep(1) + +if __name__ == "__main__": + minio_endpoint = "api.obscura.work" + minio_access_key = "MnHTAG2NOLyXXIZrwDLp" + minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf" + + mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo" + db_name = "minio_mongo" + results_collection_name = "videoupload_results" + + model_dir = "MiniCPM-V-2_6" + + system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key, + mongo_uri, db_name, model_dir, results_collection_name) + system.run() \ No newline at end of file diff --git a/api_old/__pycache__/config.cpython-310.pyc b/api_old/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..5695d51 Binary files /dev/null and b/api_old/__pycache__/config.cpython-310.pyc differ diff --git a/api_old/__pycache__/config.cpython-311.pyc b/api_old/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000..7c60927 Binary files /dev/null and b/api_old/__pycache__/config.cpython-311.pyc differ diff --git a/api_old/__pycache__/cpm_api.cpython-311.pyc b/api_old/__pycache__/cpm_api.cpython-311.pyc new file mode 100644 index 0000000..f6ee1af Binary files /dev/null and b/api_old/__pycache__/cpm_api.cpython-311.pyc differ diff --git a/api_old/__pycache__/face_api.cpython-311.pyc b/api_old/__pycache__/face_api.cpython-311.pyc new file mode 100644 index 0000000..9d35129 Binary files /dev/null and b/api_old/__pycache__/face_api.cpython-311.pyc differ diff --git a/api_old/__pycache__/fall_api.cpython-311.pyc b/api_old/__pycache__/fall_api.cpython-311.pyc new file mode 100644 index 
0000000..860bf63 Binary files /dev/null and b/api_old/__pycache__/fall_api.cpython-311.pyc differ diff --git a/api_old/__pycache__/mediapipe.cpython-310.pyc b/api_old/__pycache__/mediapipe.cpython-310.pyc new file mode 100644 index 0000000..ac74b89 Binary files /dev/null and b/api_old/__pycache__/mediapipe.cpython-310.pyc differ diff --git a/api_old/__pycache__/mediapipe.cpython-311.pyc b/api_old/__pycache__/mediapipe.cpython-311.pyc new file mode 100644 index 0000000..a2bade3 Binary files /dev/null and b/api_old/__pycache__/mediapipe.cpython-311.pyc differ diff --git a/api_old/__pycache__/mediapipe_api.cpython-311.pyc b/api_old/__pycache__/mediapipe_api.cpython-311.pyc new file mode 100644 index 0000000..133a8de Binary files /dev/null and b/api_old/__pycache__/mediapipe_api.cpython-311.pyc differ diff --git a/api_old/__pycache__/pose.cpython-311.pyc b/api_old/__pycache__/pose.cpython-311.pyc new file mode 100644 index 0000000..bce79c2 Binary files /dev/null and b/api_old/__pycache__/pose.cpython-311.pyc differ diff --git a/api_old/__pycache__/qwenvl_api.cpython-311.pyc b/api_old/__pycache__/qwenvl_api.cpython-311.pyc new file mode 100644 index 0000000..3b43c10 Binary files /dev/null and b/api_old/__pycache__/qwenvl_api.cpython-311.pyc differ diff --git a/api_old/__pycache__/yolo_api.cpython-311.pyc b/api_old/__pycache__/yolo_api.cpython-311.pyc new file mode 100644 index 0000000..8d5ec31 Binary files /dev/null and b/api_old/__pycache__/yolo_api.cpython-311.pyc differ diff --git a/api_old/before/__pycache__/mediapipe.cpython-311.pyc b/api_old/before/__pycache__/mediapipe.cpython-311.pyc new file mode 100644 index 0000000..0e013dc Binary files /dev/null and b/api_old/before/__pycache__/mediapipe.cpython-311.pyc differ diff --git a/api_old/before/cpm_api.py b/api_old/before/cpm_api.py new file mode 100644 index 0000000..172e1ce --- /dev/null +++ b/api_old/before/cpm_api.py @@ -0,0 +1,369 @@ +import os +import json +import uuid +from datetime import datetime, 
timedelta +from fastapi import FastAPI, HTTPException, UploadFile, File +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from kafka import KafkaProducer, KafkaConsumer +from transformers import AutoTokenizer, AutoModel +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis +import io +import re +import torch +import asyncio +from contextlib import asynccontextmanager +import threading + +app = FastAPI() +cpm_app = FastAPI() +app.mount("/cpm", cpm_app) + +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "cpm" +KAFKA_GROUP_ID = "cpm_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 5 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +# 设置 GPU 设备 +torch.cuda.set_device(0) + +# 初始化模型 +model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True) +model = model.half().cuda().eval() +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + +class MediaAnalysisSystem: + def __init__(self, model, tokenizer): + 
self.model = model + self.tokenizer = tokenizer + self.device = torch.device("cuda:0") + self.model = self.model.to(self.device) + self.MAX_NUM_FRAMES = 16 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + + def process_video(self, video_data, object_name): + if not video_data: + raise ValueError(f"Empty video data for {object_name}") + print(f"Processing video: {object_name}, data size: {len(video_data)} bytes") + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + answer = self.model.chat( + image=frames, # 直接传递 frames + msgs=msgs, + tokenizer=self.tokenizer, + max_length=512, + temperature=0.7, + top_p=0.9, + **params + ) + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + # "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), + # "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S") + } + + def process_image(self, image_data, object_name): + image = Image.open(io.BytesIO(image_data)) + question = "描述这张图片,包括场景、人物数量和行为等细节。" + msgs = [ + {'role': 'user', 'content': [image] + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + max_length=512, + temperature=0.7, + top_p=0.9, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + @staticmethod + def extract_time_from_filename(object_name): + filename = os.path.basename(object_name) + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + people_patterns = [ + 
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 初始化 MediaAnalysisSystem +media_analysis_system = MediaAnalysisSystem(model, tokenizer) + +async def process_file(file: UploadFile, file_type: str): + content = await file.read() + # 获取原始文件的后缀 + original_extension = os.path.splitext(file.filename)[1] + + # 生成新的文件名,包含 UUID 和原始后缀 + filename = f"cpm_{uuid.uuid4()}{original_extension}" + file_path = os.path.join(UPLOAD_DIR, filename) + with open(file_path, "wb") as f: + f.write(content) + + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": filename, + "type": file_type + }).encode('utf-8')) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + 
return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename} + +@cpm_app.post("/upload") +async def upload_file(file: UploadFile = File(...)): + try: + file_type = "image" if file.content_type.startswith("image") else "video" + return await process_file(file, file_type) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@cpm_app.post("/analyze_video") +async def analyze_video(file: UploadFile = File(...)): + try: + return await process_file(file, "video") + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@cpm_app.post("/analyze_image") +async def analyze_image(file: UploadFile = File(...)): + try: + return await process_file(file, "image") + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +def process_task(): + for message in consumer: + try: + if isinstance(message.value, dict): + task = message.value + else: + task = json.loads(message.value.decode('utf-8')) + + filename = task['filename'] + file_type = task['type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + with open(file_path, 'rb') as f: + file_data = f.read() + + if file_type == "video": + result = media_analysis_system.process_video(file_data, filename) + elif file_type == "image": + result = media_analysis_system.process_image(file_data, filename) + + # 保存结果到 JSON 文件 + result_file_path = os.path.join(RESULT_DIR, f"{filename}.json") + with open(result_file_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + # 将结果存储在 Redis 中 + redis_client.set(redis_key, json.dumps({ + "status": "completed", + "result": result + })) + + except Exception as e: + print(f"Error processing task: {str(e)}") + if 'filename' in locals() and 'file_type' in locals(): + redis_key = 
f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + else: + print("Error occurred before task details were extracted") + +@cpm_app.get("/result/{filename}") +async def get_result(filename: str): + for file_type in ["video", "image"]: + redis_key = f"{file_type}_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_json = json.loads(result) + + if result_json.get("status") == "queued": + return {"status": "queued", "message": "Your request is in the queue and will be processed soon."} + elif result_json.get("status") == "processing": + return {"status": "processing", "message": "Your request is being processed."} + else: + return result_json + + raise HTTPException(status_code=404, detail="Result not found") + +async def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@5__:*_result:*') + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + print(f"Key changed: {key}") + +if __name__ == "__main__": + # 在后台线程中启动Kafka消费者 + consumer_thread = threading.Thread(target=process_task, daemon=True) + consumer_thread.start() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=7000) \ No newline at end of file diff --git a/api_old/before/cpm_key.py b/api_old/before/cpm_key.py new file mode 100644 index 0000000..8530610 --- /dev/null +++ b/api_old/before/cpm_key.py @@ -0,0 +1,518 @@ +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer, KafkaConsumer +from transformers import AutoTokenizer, AutoModel +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis 
+import io +import re +import torch +import asyncio +from contextlib import asynccontextmanager +import threading + +app = FastAPI() +cpm_app = FastAPI() +app.mount("/cpm", cpm_app) + +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "cpm" +KAFKA_GROUP_ID = "cpm_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 5 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 + + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 添加API密钥验证 +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + +async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)): + if api_key is None: + raise HTTPException(status_code=400, detail="API密钥缺失") + return api_key + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = 
f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + return None + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + return None + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + return None + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = Image.open(file_path) + width, height = img.size + pixel_count = width * height + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + vr = VideoReader(file_path) + fps = vr.get_avg_fps() + frame_count = len(vr) + width, height = vr[0].shape[1], vr[0].shape[0] + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 
10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + +# 设置 GPU 设备 +torch.cuda.set_device(0) + +# 初始化模型 +model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True) +model = model.half().cuda().eval() +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + +class MediaAnalysisSystem: + def __init__(self, model, tokenizer): + self.model = model + self.tokenizer = tokenizer + self.device = torch.device("cuda:0") + self.model = self.model.to(self.device) + self.MAX_NUM_FRAMES = 16 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + + def process_video(self, video_data, object_name): + if not video_data: + raise ValueError(f"Empty video data for {object_name}") + print(f"Processing video: {object_name}, data size: {len(video_data)} bytes") + frames = self.encode_video(video_data) + question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior." 
+ msgs = [ + {'role': 'user', 'content': frames + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + answer = self.model.chat( + image=frames, # 直接传递 frames + msgs=msgs, + tokenizer=self.tokenizer, + max_length=512, + temperature=0.7, + top_p=0.9, + **params + ) + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "num_frames": len(frames), + # "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"), + # "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S") + } + + def process_image(self, image_data, object_name): + image = Image.open(io.BytesIO(image_data)) + question = "描述这张图片,包括场景、人物数量和行为等细节。" + msgs = [ + {'role': 'user', 'content': [image] + [question]}, + ] + + params = { + "use_image_id": False, + "max_slice_nums": 1 + } + + answer = self.model.chat( + image=None, + msgs=msgs, + tokenizer=self.tokenizer, + max_length=512, + temperature=0.7, + top_p=0.9, + **params + ) + + extracted_info = self.extract_info(answer) + + return { + "original_answer": answer, + "extracted_info": extracted_info, + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + @staticmethod + def extract_time_from_filename(object_name): + filename = os.path.basename(object_name) + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + people_patterns = [ + 
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + for action in actions: + if action in answer: + info["actions"].append(action) + + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + for interaction in interactions: + if interaction in answer: + info["interactions"].append(interaction) + + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for obj in objects: + if obj in answer: + info["objects"].append(obj) + + for item in furniture: + if item in answer: + info["furniture"].append(item) + + return info + +# 初始化 MediaAnalysisSystem +media_analysis_system = MediaAnalysisSystem(model, tokenizer) + +async def process_file(file: UploadFile, file_type: str, api_key: str): + content = await file.read() + original_extension = os.path.splitext(file.filename)[1] + + filename = f"cpm_{uuid.uuid4()}{original_extension}" + file_path = os.path.join(UPLOAD_DIR, filename) + with open(file_path, "wb") as f: + f.write(content) + + # 计算token + tokens_required = calculate_tokens(file_path, file_type) + + # 检查并更新token使用量 + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = 
int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新token使用量 + model_name = "MiniCPM-V-2_6" + await update_token_usage(api_key, tokens_required, model_name) + + + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": filename, + "type": file_type + }).encode('utf-8')) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": filename, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + +@cpm_app.post("/upload") +async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + file_type = "image" if file.content_type.startswith("image") else "video" + return await process_file(file, file_type, api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + + +@cpm_app.post("/analyze_video") +async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + return await process_file(file, "video", api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@cpm_app.post("/analyze_image") +async def analyze_image(file: UploadFile = File(...), 
api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + return await process_file(file, "image", api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + + +def process_task(): + for message in consumer: + try: + if isinstance(message.value, dict): + task = message.value + else: + task = json.loads(message.value.decode('utf-8')) + + filename = task['filename'] + file_type = task['type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + with open(file_path, 'rb') as f: + file_data = f.read() + + if file_type == "video": + result = media_analysis_system.process_video(file_data, filename) + elif file_type == "image": + result = media_analysis_system.process_image(file_data, filename) + + # 保存结果到 JSON 文件 + result_file_path = os.path.join(RESULT_DIR, f"{filename}.json") + with open(result_file_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + # 将结果存储在 Redis 中 + redis_client.set(redis_key, json.dumps({ + "status": "completed", + "result": result + })) + + except Exception as e: + print(f"Error processing task: {str(e)}") + if 'filename' in locals() and 'file_type' in locals(): + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + else: + print("Error occurred before task details were extracted") + +@cpm_app.get("/result/{filename}") +async def get_result(filename: str): + for file_type in ["video", "image"]: + redis_key = f"{file_type}_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_json = json.loads(result) + + if result_json.get("status") == "queued": + return {"status": "queued", "message": "Your request is in the queue and will be 
processed soon."} + elif result_json.get("status") == "processing": + return {"status": "processing", "message": "Your request is being processed."} + else: + return result_json + + raise HTTPException(status_code=404, detail="Result not found") + +async def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@5__:*_result:*') + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + print(f"Key changed: {key}") + +if __name__ == "__main__": + # 在后台线程中启动Kafka消费者 + consumer_thread = threading.Thread(target=process_task, daemon=True) + consumer_thread.start() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=7000) \ No newline at end of file diff --git a/api_old/before/cpm_key2.py b/api_old/before/cpm_key2.py new file mode 100644 index 0000000..8f95815 --- /dev/null +++ b/api_old/before/cpm_key2.py @@ -0,0 +1,526 @@ +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer, KafkaConsumer +from transformers import AutoTokenizer, AutoModel +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis +import io +import re +import torch +from contextlib import asynccontextmanager +import threading +import string + + +app = FastAPI() +cpm_app = FastAPI() +app.mount("/cpm", cpm_app) + +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "cpm" +KAFKA_GROUP_ID = "cpm_group" + 
+REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 5 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 + + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = 
datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = Image.open(file_path) + width, height = img.size + pixel_count = width * height + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + vr = VideoReader(file_path) + fps = vr.get_avg_fps() + frame_count = len(vr) + width, height = vr[0].shape[1], vr[0].shape[0] + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + +# 设置 GPU 设备 
class MediaAnalysisSystem:
    """Describe images and videos with a chat-style vision-language model.

    The model's free-text answer is additionally scanned with keyword and
    regex heuristics (see ``extract_info``) to produce a small structured
    summary.
    """

    # Ordered people-count patterns; the first one that matches wins.
    _PEOPLE_PATTERNS = (
        r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一|两|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
        r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
        r'(男|女)(性|生|士)',
        r'(成年|未成年|青少年|老年)\s*(人|群体)',
        r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
        r'(群众|民众|大众|公众)',
        r'(男女|老少|老幼|大人|小孩)',
    )

    # Chinese numerals understood by the count heuristics.
    _NUMERAL_VALUES = {
        '一个': 1, '一': 1, '两': 2, '二': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
    }

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and move the model to ``cuda:0``."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0")
        self.model = self.model.to(self.device)
        # Upper bound on the number of frames sent to the model per video.
        self.MAX_NUM_FRAMES = 16

    def encode_video(self, video_data):
        """Decode raw video bytes into a list of sampled PIL frames.

        Frames are taken every ``round(avg_fps)`` frames and, if that still
        yields more than ``MAX_NUM_FRAMES``, thinned out uniformly.
        """
        def uniform_sample(l, n):
            gap = len(l) / n
            return [l[int(i * gap + gap / 2)] for i in range(n)]

        video_file = io.BytesIO(video_data)
        vr = VideoReader(video_file, ctx=cpu(0))
        # Clamp to 1: an average FPS below 0.5 rounds to 0, and range()
        # rejects a zero step.
        sample_fps = max(1, round(vr.get_avg_fps() / 1))
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_video(self, video_data, object_name):
        """Ask the model to describe a video.

        Args:
            video_data: Raw bytes of the video file.
            object_name: Name used only in log and error messages.

        Returns:
            Dict with the model answer (``original_answer``), the heuristic
            summary (``extracted_info``) and ``num_frames``.

        Raises:
            ValueError: If ``video_data`` is empty.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        answer = self.model.chat(
            image=frames,  # frames are passed here as well as inside msgs
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )
        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Ask the model to describe a single image.

        Args:
            image_data: Raw bytes of the image file.
            object_name: Unused by this method; kept for signature parity
                with ``process_video``.

        Returns:
            Dict with ``original_answer``, ``extracted_info`` and a local-time
            ``timestamp`` string.
        """
        image = Image.open(io.BytesIO(image_data))
        question = "描述这张图片,包括场景、人物数量和行为等细节。"
        msgs = [
            {'role': 'user', 'content': [image] + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }

        answer = self.model.chat(
            image=None,  # the image travels inside msgs
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )

        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    @staticmethod
    def extract_time_from_filename(object_name):
        """Derive a 10-second ``(start, end)`` window from a file name.

        File names shaped like ``YYYYMMDD_HHMMSS[...]`` yield that start time.
        Any other name (including one without an underscore) falls back to
        the current local time.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')

        try:
            # parts[1] raises IndexError for names without '_'; strptime
            # raises ValueError for names that are not timestamps.
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            # Read the clock once so the window is exactly 10 seconds.
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull coarse structured facts out of the model's free-text answer.

        Returns:
            Dict with keys ``environment`` (first matching keyword or
            ``None``), ``num_people`` (``None`` when no people pattern
            matches, ``0`` when people are mentioned without a recognised
            count, otherwise the count), and the keyword lists ``actions``,
            ``interactions``, ``objects`` and ``furniture`` in their
            catalogue order.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        lowered = answer.lower()
        for env in ("办公室", "室内", "室外", "会议室"):
            if env in lowered:
                info["environment"] = env
                break

        for pattern in MediaAnalysisSystem._PEOPLE_PATTERNS:
            match = re.search(pattern, answer)
            if not match:
                continue
            token = match.group(1)
            if token.isdigit():
                info["num_people"] = int(token)
            else:
                # Non-numeric captures such as "几个" or "员工" map to 0.
                info["num_people"] = MediaAnalysisSystem._NUMERAL_VALUES.get(token, 0)
            break

        actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
        interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
        objects = ["水瓶", "办公用品", "文件", "电脑"]
        furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]

        # dict.fromkeys drops repeated catalogue entries (e.g. "转身") while
        # keeping order, so no keyword is reported twice.
        for field, keywords in (
            ("actions", actions),
            ("interactions", interactions),
            ("objects", objects),
            ("furniture", furniture),
        ):
            info[field] = [word for word in dict.fromkeys(keywords) if word in answer]

        return info
@cpm_app.post("/analyze_video")
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded video for analysis.

    Args:
        file: The uploaded video.
        api_key_info: Key and usage record already resolved by the
            ``verify_api_key`` dependency.

    Returns:
        The response produced by ``process_file``, or a 500 JSON body with
        the error text for unexpected failures.

    Raises:
        HTTPException: 403 when no key record is available; any
            ``HTTPException`` raised by ``process_file`` is passed through
            with its own status code.
    """
    # The dependency has already validated the key. Calling verify_api_key
    # again with this dict (as before) built the Redis key from the dict's
    # repr, so the lookup always missed and every request was rejected.
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    try:
        return await process_file(file, "video", api_key_info)
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. quota exceeded) instead of
        # flattening them into a 500.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
def process_task():
    """Consume analysis tasks from Kafka and publish their results.

    For every message the task status in Redis moves from ``processing`` to
    ``completed`` (with the result, which is also written to a JSON file in
    ``RESULT_DIR``) or to ``failed`` (with the error text). A failing task
    never stops the loop.
    """
    for message in consumer:
        # Reset per message: otherwise a failure while decoding this message
        # would be recorded against the previous task's Redis key.
        filename = None
        file_type = None
        try:
            payload = message.value
            if isinstance(payload, dict):
                task = payload
            else:
                task = json.loads(payload.decode('utf-8'))

            filename = task['filename']
            file_type = task['type']

            # Reject unknown types up front so no stale ``result`` from an
            # earlier iteration can be stored for this task.
            if file_type not in ("video", "image"):
                raise ValueError(f"Unsupported file type: {file_type}")

            file_path = os.path.join(UPLOAD_DIR, filename)

            redis_key = f"{file_type}_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))

            with open(file_path, 'rb') as f:
                file_data = f.read()

            if file_type == "video":
                result = media_analysis_system.process_video(file_data, filename)
            else:
                result = media_analysis_system.process_image(file_data, filename)

            # Persist the result as a JSON file.
            result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
            with open(result_file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            # Publish the result through Redis.
            redis_client.set(redis_key, json.dumps({
                "status": "completed",
                "result": result
            }))

        except Exception as e:
            print(f"Error processing task: {str(e)}")
            if filename is not None and file_type is not None:
                redis_key = f"{file_type}_result:{filename}"
                redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
            else:
                print("Error occurred before task details were extracted")
"processing": + return {"status": "processing", "message": "Your request is being processed."} + else: + return result_json + + raise HTTPException(status_code=404, detail="Result not found") + +async def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@5__:*_result:*') + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + print(f"Key changed: {key}") + +if __name__ == "__main__": + # 在后台线程中启动Kafka消费者 + consumer_thread = threading.Thread(target=process_task, daemon=True) + consumer_thread.start() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=7000) \ No newline at end of file diff --git a/api_old/before/face_api.py b/api_old/before/face_api.py new file mode 100644 index 0000000..118babe --- /dev/null +++ b/api_old/before/face_api.py @@ -0,0 +1,312 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import io +import uuid +import os +from datetime import datetime, timedelta +import threading +import torch +torch.cuda.set_device(1) + + +app = FastAPI() +face_app = FastAPI() +app.mount("/face", face_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/yolov8n-face.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "face" # 指定Kafka topic +KAFKA_GROUP_ID = "face_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 7 + +UPLOAD_DIR = 
"/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +class faceDetector: + def __init__(self, model_path): + self.model = YOLO(model_path).to('cuda:1') + + def detect(self, frame): + results = self.model(frame, device='cuda:1') + return results + + def format_results(self, results): + formatted_results = [] + for r in results: + boxes = r.boxes + keypoints = r.keypoints + for i in range(len(boxes)): + box = boxes[i] + kpts = keypoints[i] + formatted_results.append({ + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "keypoints": kpts.xy.tolist()[0] + }) + return formatted_results + + def draw_results(self, frame, results, original_shape): + for r in results: + annotated_frame = r.plot(img=frame) + # 调整坐标以适应原始图像大小 + h, w = annotated_frame.shape[:2] + scale_x, scale_y = original_shape[1] / w, original_shape[0] / h + annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0])) + return annotated_frame + +detector = faceDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + original_shape = img.shape + # Convert BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # Resize image to fit model requirements (640x640) + img_resized = cv2.resize(img, (640, 640)) + + # Normalize and reshape to BCHW format + img_tensor = 
def process_video(video_data, filename):
    """Run face detection on one frame per second of a video.

    The bytes are written to a temporary file in ``UPLOAD_DIR``, sampled
    frames are sent to ``detector``, and an output video of the same size is
    written to ``RESULT_DIR``; frames that were not sampled are copied
    through unchanged.

    Args:
        video_data: Raw bytes of the video file.
        filename: Name of the upload; the temp and result files are both
            named ``face_<filename>``.

    Returns:
        ``(json_results, annotated_filename)`` where ``json_results`` holds
        one ``{"frame", "detections"}`` entry per sampled frame, or
        ``(None, None)`` if anything fails.
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"face_{filename}")
    cap = None
    out = None
    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        frame_count = 0
        json_results = []

        # Get video properties. FPS is clamped to 1 because a reported
        # value of 0 would make the modulo below divide by zero.
        fps = max(1, int(cap.get(cv2.CAP_PROP_FPS)))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_shape = (height, width)

        # Create output video file
        annotated_filename = f"face_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Process one frame per second
            if frame_count % fps == 0:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Resize frame to fit model requirements (640x640)
                frame_resized = cv2.resize(frame_rgb, (640, 640))

                # Normalize and reshape to BCHW format
                frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
                frame_tensor = frame_tensor.to('cuda:1')

                results = detector.detect(frame_tensor)
                frame_json_results = detector.format_results(results)
                json_results.append({"frame": frame_count, "detections": frame_json_results})

                # Draw on the resized frame; draw_results scales it back.
                annotated_frame = detector.draw_results(frame_resized, results, original_shape)
                # Convert RGB back to BGR for OpenCV
                annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
            else:
                annotated_frame = frame

            out.write(annotated_frame)
            frame_count += 1

        return json_results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Release handles and drop the temp copy on every path, so a failed
        # run leaks neither descriptors nor disk space.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
def listen_redis_changes():
    """Log status changes of ``face_result:*`` keys via keyspace events.

    Blocks forever on the pub/sub connection, so it is meant to run in its
    own thread. Whether any event arrives depends on the server's
    ``notify-keyspace-events`` setting, which this function does not change.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are per database. Subscribe to the database that
    # redis_client writes to (REDIS_DB); the previous hard-coded
    # "__keyspace@3__" never matched it.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:face_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # Channel is "__keyspace@<db>__:face_result:<filename>".
        key = message['channel'].decode('utf-8').split(':')[-1]
        operation = message['data'].decode('utf-8')
        if operation != 'set':
            continue

        value = redis_client.get(f"face_result:{key}")
        if value:
            result = json.loads(value)
            # .get(): a payload without "status" must not kill the listener.
            print(f"Status update for {key}: {result.get('status')}")

        # Further handling (e.g. sending notifications) can be added here.
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
    """Return the API key sent in the ``API_KEY_NAME`` request header.

    Raises:
        HTTPException: 400 when the header is absent.
    """
    if api_key is not None:
        return api_key
    raise HTTPException(status_code=400, detail="API密钥缺失")
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Look up an API key and return its merged metadata and usage record.

    Args:
        api_key: Key string resolved by ``get_api_key``.

    Returns:
        A dict combining the hash stored under ``api_key:<key>`` in the key
        database with the hash of the same name in the usage database, or
        ``None`` when the key is unknown, inactive, expired, or carries no
        parseable ``expires_at``.
    """
    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        return None

    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}

    if api_key_info.get('is_active') != '1':
        return None

    # A missing or malformed expiry is treated as an invalid key (fail
    # closed) instead of escaping as an unhandled TypeError/ValueError.
    try:
        expires_at = datetime.fromisoformat(api_key_info['expires_at'])
    except (KeyError, ValueError):
        return None
    # Naive and aware datetimes cannot be compared; timestamps stored
    # without an offset are interpreted as UTC.
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if datetime.now(timezone.utc) > expires_at:
        return None

    raw_usage = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_usage.items()}

    return {
        **api_key_info,
        **usage_info
    }
def process_image(image_data, filename):
    """Detect faces in an image and save an annotated copy.

    Args:
        image_data: Raw encoded image bytes.
        filename: Name of the upload; the annotated copy is written to
            ``RESULT_DIR`` as ``face_<filename>``.

    Returns:
        ``(json_results, annotated_filename)`` on success, ``(None, None)``
        when any step raises.
    """
    try:
        raw_buffer = np.frombuffer(image_data, np.uint8)
        bgr_image = cv2.imdecode(raw_buffer, cv2.IMREAD_COLOR)
        original_shape = bgr_image.shape

        # The detector is fed RGB at a fixed 640x640 resolution.
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        model_input = cv2.resize(rgb_image, (640, 640))

        # HWC uint8 -> 1xCxHxW float scaled to [0, 1], on the second GPU.
        tensor = torch.from_numpy(model_input).float()
        tensor = tensor.permute(2, 0, 1).unsqueeze(0) / 255.0
        tensor = tensor.to('cuda:1')

        results = detector.detect(tensor)
        json_results = detector.format_results(results)

        # Annotate the resized frame; draw_results scales it back to the
        # original resolution.
        annotated_rgb = detector.draw_results(model_input, results, original_shape)

        annotated_filename = f"face_{filename}"
        annotated_path = os.path.join(RESULT_DIR, annotated_filename)
        annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(annotated_path, annotated_bgr)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None, None

    return json_results, annotated_filename
@face_app.post("/upload")
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
    """Store an upload, charge tokens for it and queue it for face detection.

    Args:
        file: The uploaded image or video.
        api_key: Key string resolved by ``get_api_key``.

    Returns:
        JSON with the generated file name and the caller's token usage.

    Raises:
        HTTPException: 403 for an invalid key or insufficient token balance,
            500 when no token cost can be computed.
    """
    # Validate the API key.
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    content = await file.read()
    file_extension = os.path.splitext(file.filename)[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"
    file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"

    # Save the original file
    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    try:
        # Price the upload.
        tokens_required = calculate_tokens(original_file_path, file_type)
        if tokens_required is None or tokens_required <= 0:
            raise HTTPException(status_code=500, detail="无法计算所需的token数量")

        # Check the balance before charging.
        usage_key = f"api_key:{api_key}"
        total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
        tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

        if tokens_used + tokens_required > total_tokens:
            raise HTTPException(status_code=403, detail="Token 余额不足")
    except HTTPException:
        # A rejected upload will never be processed; do not leave it on disk.
        if os.path.exists(original_file_path):
            os.remove(original_file_path)
        raise

    # Record the token usage.
    model_name = "yolov8n-face"
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish "queued" BEFORE handing the task to Kafka: written afterwards,
    # it could overwrite the worker's "processing"/"completed" status.
    redis_key = f"face_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    # Send processing task to Kafka
    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": file_type
    }).encode('utf-8'))

    # Re-read usage for the response. The key may have become invalid in the
    # meantime (lookup returns None), hence the empty-dict fallback.
    updated_api_key_info = await verify_api_key(api_key) or {}
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": new_filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
def process_task():
    """Consume face-detection tasks from Kafka and publish their status.

    Each task's Redis entry ``face_result:<filename>`` moves from
    ``processing`` to ``completed`` (with detections and the annotated file
    name) or ``failed``. A bad message or a failing task is logged and the
    loop continues with the next message.
    """
    for message in consumer:
        redis_key = None
        try:
            task = message.value
            filename = task['filename']
            file_type = task['file_type']

            file_path = os.path.join(UPLOAD_DIR, filename)

            # Update status to "processing"
            redis_key = f"face_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))

            with open(file_path, 'rb') as f:
                content = f.read()

            handler = process_image if file_type == "image" else process_video
            json_results, annotated_filename = handler(content, filename)

            # The handlers return (None, None) on failure. An empty list
            # means "processed, nothing detected" and is a success.
            if json_results is not None and annotated_filename:
                redis_client.set(redis_key, json.dumps({
                    "json_results": json_results,
                    "status": "completed",
                    "annotated_filename": annotated_filename
                }))
            else:
                redis_client.set(redis_key, json.dumps({"status": "failed"}))
        except Exception as e:
            print(f"Error processing task: {str(e)}")
            # redis_key is still None when the message itself was malformed.
            if redis_key is not None:
                redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"face_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7004) \ No newline at end of file diff --git a/api_old/before/face_key2.py b/api_old/before/face_key2.py new file mode 100644 index 0000000..02e40e4 --- /dev/null +++ b/api_old/before/face_key2.py @@ -0,0 +1,471 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import APIKeyHeader +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import uuid +import os +from datetime import datetime, timedelta, timezone +import threading +import torch + +import string +torch.cuda.set_device(1) + + +app = FastAPI() +face_app = FastAPI() +app.mount("/face", face_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/yolov8n-face.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "face" # 指定Kafka topic +KAFKA_GROUP_ID = "face_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 
+REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 7 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = cv2.imread(file_path) + if img is None: + raise ValueError("无法读取图片文件") + height, width = img.shape[:2] + pixel_count = height * width + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + 
elif file_type == "video": + cap = cv2.VideoCapture(file_path) + if not cap.isOpened(): + raise ValueError("无法打开视频文件") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + 
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +class faceDetector: + def __init__(self, model_path): + self.model = YOLO(model_path).to('cuda:1') + + def detect(self, frame): + results = self.model(frame, device='cuda:1') + return results + + def format_results(self, results): + formatted_results = [] + for r in results: + boxes = r.boxes + keypoints = r.keypoints + for i in range(len(boxes)): + box = boxes[i] + kpts = keypoints[i] + formatted_results.append({ + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "keypoints": kpts.xy.tolist()[0] + }) + return formatted_results + + def draw_results(self, frame, results, original_shape): + for r in results: + annotated_frame = r.plot(img=frame) + # 调整坐标以适应原始图像大小 + h, w = annotated_frame.shape[:2] + scale_x, scale_y = original_shape[1] / w, original_shape[0] / h + annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0])) + return annotated_frame + +detector = faceDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + original_shape = img.shape + # Convert BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # Resize image to fit model requirements (640x640) + img_resized = cv2.resize(img, (640, 640)) + + # Normalize and reshape to BCHW format + img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + img_tensor = img_tensor.to('cuda:1') + + results = detector.detect(img_tensor) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on original image + annotated_img = detector.draw_results(img_resized, results, original_shape) + + # Save annotated image + annotated_filename = f"face_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, 
cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR)) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"face_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_shape = (height, width) + + # Create output video file + annotated_filename = f"face_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + # Process one frame per second + if frame_count % fps == 0: + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize frame to fit model requirements (640x640) + frame_resized = cv2.resize(frame_rgb, (640, 640)) + + # Normalize and reshape to BCHW format + frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + frame_tensor = frame_tensor.to('cuda:1') + + results = detector.detect(frame_tensor) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + # Draw results on original frame + annotated_frame = detector.draw_results(frame_resized, results, original_shape) + # Convert RGB back to BGR for OpenCV + annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) + else: + annotated_frame = frame + + out.write(annotated_frame) + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video file + os.remove(temp_video_path) + + return 
json_results, annotated_filename + except Exception as e: + print(f"Error processing video: {str(e)}") + return None, None + +@face_app.post("/upload") +async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # Save the original file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(original_file_path, file_type) + if tokens_required is None or tokens_required <= 0: + raise HTTPException(status_code=500, detail="无法计算所需的token数量") + + + # 检查并更新 token 使用量 + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + model_name = "yolov8n-face" + await update_token_usage(api_key, tokens_required, model_name) + + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"face_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": new_filename, + 
"tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + +@face_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_face_result(filename: str): + redis_key = f"face_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + return JSONResponse(content=result_data) # 直接返回整个结果,包括 status + else: + raise HTTPException(status_code=404, detail="Result not found") + +@face_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_annotated_file(filename: str): + redis_key = f"face_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + if result_data["status"] == "completed": + annotated_filename = result_data["annotated_filename"] + file_path = os.path.join(RESULT_DIR, annotated_filename) + if os.path.exists(file_path): + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + file_extension = os.path.splitext(annotated_filename)[1].lower() + return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4") + + raise HTTPException(status_code=404, detail="Annotated file not found") + +def process_task(): + for message in consumer: + task = message.value + filename = task['filename'] + file_type = task['file_type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + # Update status to "processing" + redis_key = f"face_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + try: + if file_type == "image": + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_image(content, filename) + else: + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = 
process_video(content, filename) + + if json_results and annotated_filename: + redis_client.set(redis_key, json.dumps({ + "json_results": json_results, + "status": "completed", + "annotated_filename": annotated_filename + })) + else: + redis_client.set(redis_key, json.dumps({"status": "failed"})) + except Exception as e: + print(f"Error processing task: {str(e)}") + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"face_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7004) \ No newline at end of file diff --git a/api_old/before/fall_api.py b/api_old/before/fall_api.py new file mode 100644 index 0000000..d85705a --- /dev/null +++ b/api_old/before/fall_api.py @@ -0,0 +1,293 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import io +import uuid +import os +from datetime import datetime, timedelta +import threading + +app = FastAPI() +fall_app = FastAPI() +app.mount("/fall", fall_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 
+app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/yolov8n-fall.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "fall" # 指定Kafka topic +KAFKA_GROUP_ID = "fall_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 4 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +class fallDetector: + def __init__(self, model_path): + self.model = YOLO(model_path) + def detect(self, frame): + results = self.model(frame) + return results + def format_results(self, results): + formatted_results = [] + for r in results: + if not hasattr(r, 'boxes') or len(r.boxes) == 0: + print("没有检测到任何对象") + return [{"message": "No objects detected"}] + + boxes = r.boxes + names = getattr(r, 'names', {}) + + for i in range(len(boxes)): + box = boxes[i] + if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'): + print(f"警告: 第 {i} 个框缺少必要的属性") + continue + + try: + class_id = int(box.cls.item()) + formatted_result = { + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "class_id": class_id, + "class": names.get(class_id, f"Unknown-{class_id}") + } + formatted_results.append(formatted_result) + except Exception as e: + 
print(f"处理第 {i} 个框时出错: {str(e)}") + + # print("格式化后的结果:", formatted_results) + return formatted_results + + def draw_results(self, frame, results): + for r in results: + annotated_frame = r.plot() + return annotated_frame + +detector = fallDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + + results = detector.detect(img) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on image + annotated_img = detector.draw_results(img, results) + + # Save annotated image + annotated_filename = f"fall_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, annotated_img) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + + + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Create output video file + annotated_filename = f"fall_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + results = detector.detect(frame) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + annotated_frame = detector.draw_results(frame, results) + out.write(annotated_frame) + + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video 
file + os.remove(temp_video_path) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing video: {str(e)}") + return None, None + +@fall_app.post("/upload") +async def upload_file(file: UploadFile = File(...)): + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # Save the original file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"fall_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename}) + +@fall_app.get("/result/{filename}") +async def get_fall_result(filename: str): + redis_key = f"fall_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + return JSONResponse(content=result_data) # 直接返回整个结果,包括 status + else: + raise HTTPException(status_code=404, detail="Result not found") + +@fall_app.get("/annotated/{filename}") +async def get_annotated_file(filename: str): + redis_key = f"fall_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + if result_data["status"] == "completed": + annotated_filename = result_data["annotated_filename"] + file_path = os.path.join(RESULT_DIR, annotated_filename) + if os.path.exists(file_path): + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + file_extension = os.path.splitext(annotated_filename)[1].lower() + return StreamingResponse(iterfile(), 
media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4") + + raise HTTPException(status_code=404, detail="Annotated file not found") + +def process_task(): + for message in consumer: + task = message.value + filename = task['filename'] + file_type = task['file_type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + # Update status to "processing" + redis_key = f"fall_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + try: + if file_type == "image": + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_image(content, filename) + else: + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_video(content, filename) + + if json_results and annotated_filename: + redis_client.set(redis_key, json.dumps({ + "json_results": json_results, + "status": "completed", + "annotated_filename": annotated_filename + })) + else: + redis_client.set(redis_key, json.dumps({"status": "failed"})) + except Exception as e: + print(f"Error processing task: {str(e)}") + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"fall_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7002) \ No newline at end of 
file diff --git a/api_old/before/fall_key.py b/api_old/before/fall_key.py new file mode 100644 index 0000000..3894b2b --- /dev/null +++ b/api_old/before/fall_key.py @@ -0,0 +1,442 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import APIKeyHeader +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import io +import uuid +import os +from datetime import datetime, timedelta, timezone +import threading + +app = FastAPI() +fall_app = FastAPI() +app.mount("/fall", fall_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/yolov8n-fall.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "fall" # 指定Kafka topic +KAFKA_GROUP_ID = "fall_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 4 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB 
+) +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 添加API密钥验证 +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + +async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)): + if api_key is None: + raise HTTPException(status_code=400, detail="API密钥缺失") + return api_key + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = cv2.imread(file_path) + if img is None: + raise ValueError("无法读取图片文件") + height, width = img.shape[:2] + pixel_count = height * width + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + cap = cv2.VideoCapture(file_path) + if not cap.isOpened(): + raise ValueError("无法打开视频文件") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + return None + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} 
+ + if api_key_info.get('is_active') != '1': + return None + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + return None + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +class fallDetector: + def __init__(self, model_path): + self.model = YOLO(model_path) + def detect(self, frame): + results = self.model(frame) + return results + def format_results(self, results): + formatted_results = [] + for r in results: + if not hasattr(r, 'boxes') or len(r.boxes) == 0: + print("没有检测到任何对象") + return [{"message": "No objects detected"}] + + boxes = r.boxes + names = getattr(r, 'names', {}) + + for i in range(len(boxes)): + box = boxes[i] + if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'): + print(f"警告: 第 {i} 个框缺少必要的属性") + continue + + try: + class_id = int(box.cls.item()) + formatted_result = { + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "class_id": class_id, + "class": names.get(class_id, f"Unknown-{class_id}") + } + formatted_results.append(formatted_result) + except Exception as e: + print(f"处理第 {i} 个框时出错: {str(e)}") + + # print("格式化后的结果:", formatted_results) + return formatted_results + + def 
draw_results(self, frame, results): + for r in results: + annotated_frame = r.plot() + return annotated_frame + +detector = fallDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + + results = detector.detect(img) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on image + annotated_img = detector.draw_results(img, results) + + # Save annotated image + annotated_filename = f"fall_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, annotated_img) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + + + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Create output video file + annotated_filename = f"fall_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + results = detector.detect(frame) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + annotated_frame = detector.draw_results(frame, results) + out.write(annotated_frame) + + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video file + os.remove(temp_video_path) + + return json_results, annotated_filename + except Exception as e: + 
@fall_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Accept an upload, charge tokens for it and queue it for fall detection.

    The file is stored in ``UPLOAD_DIR`` under a fresh UUID name, its token
    cost is computed with ``calculate_tokens``, the key's balance is checked
    and debited, and a task is published to ``KAFKA_TOPIC``.

    Raises:
        HTTPException: 403 when the key is rejected or the balance is too
            low, 500 when the token cost cannot be computed.
    """
    # Validate the API key.
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    content = await file.read()
    # UploadFile.filename may be None; treat that as "no extension".
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"
    file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"

    # Save the original file; the token calculation reads it from disk.
    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    usage_key = f"api_key:{api_key}"
    try:
        tokens_required = calculate_tokens(original_file_path, file_type)
        if tokens_required is None or tokens_required <= 0:
            raise HTTPException(status_code=500, detail="无法计算所需的token数量")

        # NOTE(review): the balance check and the increment below are two
        # separate Redis round trips, so concurrent uploads with the same key
        # can overdraw it. Confirm whether that matters.
        total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
        tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
        if tokens_used + tokens_required > total_tokens:
            raise HTTPException(status_code=403, detail="Token 余额不足")
    except HTTPException:
        # The request is rejected: do not leave the stored upload behind.
        if os.path.exists(original_file_path):
            os.remove(original_file_path)
        raise

    # Debit the tokens.
    model_name = "yolov8n-fall"
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish the initial status BEFORE the task: the worker sets
    # "processing" as soon as it consumes the message, and writing "queued"
    # afterwards could overwrite that newer state.
    redis_key = f"fall_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    # Send processing task to Kafka.
    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": file_type
    }).encode('utf-8'))

    # Re-read the usage counters so the response reflects this request.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": new_filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
def listen_redis_changes():
    """Log status changes of ``fall_result:*`` keys via keyspace notifications.

    Blocks forever on the pub/sub connection, so it is meant to run in its
    own thread. For every ``set`` event the key's JSON value is re-read
    through ``redis_client`` and its ``status`` printed.

    NOTE(review): keyspace notifications are only delivered when the server
    has ``notify-keyspace-events`` enabled; this function does not check it.
    """
    # Keyspace channels are per database. Subscribe to the database this
    # client is bound to rather than a hard-coded index, otherwise no event
    # arrives whenever the configured DB differs from the literal.
    db_index = redis_client.connection_pool.connection_kwargs.get("db", 0)
    pubsub = redis_client.pubsub()
    pubsub.psubscribe(f'__keyspace@{db_index}__:fall_result:*')  # all fall_result keys

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # Channel looks like "__keyspace@<db>__:fall_result:<name>"; split at
        # most twice so a ':' inside <name> is kept.
        key = message['channel'].decode('utf-8').split(':', 2)[-1]
        operation = message['data'].decode('utf-8')
        if operation != 'set':
            continue

        value = redis_client.get(f"fall_result:{key}")
        if value:
            result = json.loads(value)
            print(f"Status update for {key}: {result['status']}")

        # Further handling (e.g. sending notifications) can be added here.
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the API key from an ``Authorization: Bearer obs-...`` header.

    Returns:
        The token following ``"Bearer "`` when it starts with ``"obs-"``.

    Raises:
        HTTPException: 401 when the header is missing, is not a Bearer
            header, or the token lacks the ``obs-`` prefix.
    """
    header_value = api_key or ""
    # A value that starts with "Bearer " always splits into at least two parts.
    token = header_value.split(" ")[1] if header_value.startswith("Bearer ") else ""
    if not token.startswith("obs-"):
        raise HTTPException(
            status_code=401,
            detail="无效的API密钥",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Return the number of tokens to charge for processing ``file_path``.

    The charge is a size component plus a content component:

    * size: 10 tokens per MiB of file size;
    * ``"image"``: 5 tokens per 10,000 pixels;
    * ``"video"``: ``(width * height * frames / 10,000) * (minutes)``.

    Any other ``file_type`` is charged by size only.

    Returns:
        At least 1. Also 1 when the file cannot be read or decoded (the error
        is printed).
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes

        # Size component: 10 tokens per MiB.
        base_tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            img = cv2.imread(file_path)
            if img is None:
                raise ValueError("无法读取图片文件")
            height, width = img.shape[:2]
            pixel_count = height * width

            # Image component: 5 tokens per 10,000 pixels.
            base_tokens += int((pixel_count / 10000) * 5)

        elif file_type == "video":
            cap = cv2.VideoCapture(file_path)
            try:
                if not cap.isOpened():
                    raise ValueError("无法打开视频文件")
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # Release the capture on the error path as well.
                cap.release()

            pixel_count = width * height * frame_count
            # Some containers report fps == 0. Treat the duration as unknown
            # (0 s) instead of dividing by zero, which used to be swallowed
            # below and billed the whole video as a single token.
            duration = frame_count / fps if fps > 0 else 0.0  # seconds

            # Video component: total pixels / 10,000, scaled by minutes.
            base_tokens += int((pixel_count / 10000) * (duration / 60))

        return max(1, base_tokens)  # never charge less than one token
    except Exception as e:
        print(f"计算token时出错: {str(e)}")
        return 1  # fall back to the minimum charge on any error
class fallDetector:
    """Thin wrapper around a YOLO model used for fall detection."""

    def __init__(self, model_path):
        """Load the YOLO weights stored at ``model_path``."""
        self.model = YOLO(model_path)

    def detect(self, frame):
        """Run the model on ``frame`` and return its results unchanged."""
        return self.model(frame)

    def format_results(self, results):
        """Convert model results into a JSON-serialisable list.

        Each detection becomes a dict with ``bbox``, ``confidence``,
        ``class_id`` and ``class``. Boxes lacking ``cls``/``conf``/``xyxy``
        or failing conversion are skipped with a printed warning.

        Returns:
            The list of detections. When ``results`` is non-empty but no
            result carries any box, ``[{"message": "No objects detected"}]``.
        """
        formatted_results = []
        seen_results = 0
        seen_boxes = False

        for r in results:
            seen_results += 1
            boxes = getattr(r, 'boxes', None)
            # A result without boxes must not discard detections already
            # collected from other results, so skip it instead of returning.
            if boxes is None or len(boxes) == 0:
                continue
            seen_boxes = True
            names = getattr(r, 'names', {})

            for i in range(len(boxes)):
                box = boxes[i]
                if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'):
                    print(f"警告: 第 {i} 个框缺少必要的属性")
                    continue

                try:
                    class_id = int(box.cls.item())
                    formatted_results.append({
                        "bbox": box.xyxy.tolist()[0],
                        "confidence": box.conf.item(),
                        "class_id": class_id,
                        "class": names.get(class_id, f"Unknown-{class_id}")
                    })
                except Exception as e:
                    print(f"处理第 {i} 个框时出错: {str(e)}")

        if seen_results and not seen_boxes:
            print("没有检测到任何对象")
            return [{"message": "No objects detected"}]
        return formatted_results

    def draw_results(self, frame, results):
        """Return the frame annotated by the last result's ``plot()``.

        When ``results`` is empty the input ``frame`` is returned unchanged
        (previously this raised ``UnboundLocalError``).
        """
        annotated_frame = frame
        for r in results:
            annotated_frame = r.plot()
        return annotated_frame
def process_video(video_data, filename):
    """Run fall detection on every frame of a video and write an annotated copy.

    The bytes are first written to ``UPLOAD_DIR/fall_<filename>`` because
    ``cv2.VideoCapture`` is opened on a path; the annotated video goes to
    ``RESULT_DIR/fall_<filename>`` using the ``mp4v`` codec.

    Args:
        video_data: Raw bytes of the video file.
        filename: Name of the uploaded file.

    Returns:
        ``(json_results, annotated_filename)`` where ``json_results`` holds one
        ``{"frame": index, "detections": [...]}`` entry per frame, or
        ``(None, None)`` on any error (the error is printed).
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
    cap = None
    out = None
    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        frame_count = 0
        json_results = []

        # Video properties reused for the output writer.
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        annotated_filename = f"fall_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            results = detector.detect(frame)
            frame_json_results = detector.format_results(results)
            json_results.append({"frame": frame_count, "detections": frame_json_results})

            out.write(detector.draw_results(frame, results))
            frame_count += 1

        return json_results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Release handles and drop the temporary input on the error path too;
        # previously a failing frame leaked both and left the temp file behind.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(original_file_path, file_type) + if tokens_required is None or tokens_required <= 0: + raise HTTPException(status_code=500, detail="无法计算所需的token数量") + # 检查并更新 token 使用量 + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + model_name = "yolov8n-fall" + await update_token_usage(api_key, tokens_required, model_name) + + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"fall_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": new_filename, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + + +@fall_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_fall_result(filename: str): + redis_key = f"fall_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = 
@fall_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
async def get_annotated_file(filename: str):
    """Stream the annotated image or video produced for ``filename``.

    The task record is read from ``fall_result:<filename>``; the file is
    served only when the record says ``completed`` and the annotated file
    exists in ``RESULT_DIR``.

    Raises:
        HTTPException: 404 when there is no record, the task is not
            completed, or the annotated file is missing.
    """
    # ".jpg" must be served as image/jpeg; "image/jpg" is not a registered type.
    image_media_types = {'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png'}
    not_found = HTTPException(status_code=404, detail="Annotated file not found")

    result = redis_client.get(f"fall_result:{filename}")
    if not result:
        raise not_found

    result_data = json.loads(result)
    annotated_filename = result_data.get("annotated_filename")
    # .get() keeps an incomplete record a 404 instead of a KeyError/500.
    if result_data.get("status") != "completed" or not annotated_filename:
        raise not_found

    file_path = os.path.join(RESULT_DIR, annotated_filename)
    if not os.path.exists(file_path):
        raise not_found

    def iterfile(chunk_size=64 * 1024):
        # Fixed-size chunks: iterating a binary file "by line" yields pieces
        # of arbitrary size depending on where 0x0A bytes happen to fall.
        with open(file_path, mode="rb") as file_like:
            while True:
                chunk = file_like.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    file_extension = os.path.splitext(annotated_filename)[1].lower()
    return StreamingResponse(
        iterfile(),
        media_type=image_media_types.get(file_extension, "video/mp4"),
    )
"failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"fall_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7002) \ No newline at end of file diff --git a/api_old/before/mediapipe.py b/api_old/before/mediapipe.py new file mode 100644 index 0000000..9d84690 --- /dev/null +++ b/api_old/before/mediapipe.py @@ -0,0 +1,297 @@ +import os +import cv2 +import torch +import numpy as np +from redis import Redis +import json +from kafka import KafkaConsumer +import threading +import redis +import torch +import media as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision + +# Configuration +MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "mediapipe" +KAFKA_GROUP_ID = "mediapipe_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 9 # POSE Worker使用的Redis DB +MAIN_REDIS_DB = 15 # 主Redis DB +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" + +# Ensure directories exist +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# Initialize Kafka +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + 
class mediapipeEmbedder:
    """Face-landmark extractor built on MediaPipe's ``FaceLandmarker``."""

    def __init__(self, model_path):
        """Create a single-face landmarker from the model at ``model_path``."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Return the first face's landmarks as an ``(N, 3)`` array, or None.

        ``image`` is wrapped as ``ImageFormat.SRGB``, so it must be in RGB
        channel order.
        """
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
        return None

    def process_image(self, image_data):
        """Detect landmarks in an encoded image and draw them on it.

        Returns:
            ``({"embedding": [...], "landmarks": [...]}, annotated_bgr_image)``
            when a face is found, otherwise ``(None, bgr_image)``.

        Raises:
            ValueError: when ``image_data`` cannot be decoded.
        """
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("image data could not be decoded")

        # OpenCV decodes to BGR, but the landmarker is told the data is SRGB;
        # convert so red and blue are not swapped.
        landmarks = self.get_mediapipe_landmarks(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if landmarks is None:
            return None, img

        embedding = self.calculate_detailed_embedding(landmarks)

        # x/y are scaled by the image width/height before drawing.
        img_height, img_width = img.shape[:2]
        for lm in landmarks:
            cv2.circle(img, (int(lm[0] * img_width), int(lm[1] * img_height)), 2, (0, 255, 0), -1)

        return {
            "embedding": embedding,
            "landmarks": landmarks.tolist()
        }, img

    def calculate_detailed_embedding(self, landmarks):
        """Build a 21-value descriptor from an ``(N, 3)`` landmark array.

        Layout: per-axis mean, std, median, min and max (15 values), then eye
        distance, mouth width, nose-to-mouth distance, and the x/y/z extents
        of the landmark set (6 values). Requires at least 387 landmarks
        because indices 4, 61, 159, 291 and 386 are read.
        """
        # Per-axis statistics over all landmarks.
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Distances between selected landmarks.
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]

        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Extent of the landmark cloud along each axis.
        face_width = max_vals[0] - min_vals[0]
        face_height = max_vals[1] - min_vals[1]
        face_depth = max_vals[2] - min_vals[2]

        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
        ])

        return embedding.tolist()
def process_task():
    """Consume tasks from Kafka, process each file and record the outcome.

    For every message the task hash ``task:<task_id>`` in the main Redis DB
    is set to ``processing``, the file is run through ``process_image`` or
    ``process_video``, the result is stored in the worker DB and the task
    hash is set to ``completed`` or ``failed``. Runs until the consumer stops.

    NOTE(review): results are stored under ``fall_result:<task_id>`` with
    ``result_type`` "fall"; left unchanged because readers of those keys are
    not visible here. Confirm whether that naming is intended for this worker.
    """
    print("开始处理任务,等待Kafka消息...")
    for message in consumer:
        print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
        task = message.value
        task_id = task['task_id']
        filename = task['filename']
        file_type = task['file_type']

        print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")

        file_path = os.path.join(UPLOAD_DIR, filename)
        task_key = f"task:{task_id}"

        # The task key must be a hash; drop anything else before updating it.
        try:
            key_type = main_redis_client.type(task_key)
            if key_type != b'hash':
                main_redis_client.delete(task_key)
            main_redis_client.hset(task_key, "status", "processing")
            print(f"任务 {task_id} 状态更新为 'processing'")
        except redis.exceptions.ResponseError as e:
            print(f"更新任务 {task_id} 状态时出错: {str(e)}")
            continue  # skip this task and wait for the next one

        # Both kinds follow the same steps; only the handler and label differ.
        is_image = file_type == "image"
        label = "图像" if is_image else "视频"
        handler = process_image if is_image else process_video

        try:
            print(f"开始处理{label}: {filename}")
            with open(file_path, 'rb') as f:
                payload = f.read()
            json_results, annotated_filename = handler(payload, filename)

            if json_results and annotated_filename:
                result_key = f"fall_result:{task_id}"
                redis_client.hset(result_key, mapping={
                    "result": json.dumps(json_results),
                    "result_file": annotated_filename
                })
                main_redis_client.hset(task_key, mapping={
                    "status": "completed",
                    "result_type": "fall",
                    "result_key": result_key
                })
                print(f"{label} {filename} 处理完成,结果已保存")
            else:
                print(f"{label} {filename} 处理失败")
                main_redis_client.hset(task_key, "status", "failed")
        except Exception as e:
            print(f"处理任务 {task_id} 时出错: {str(e)}")
            # Several fields must go through mapping=; passing the dict
            # positionally made this handler itself raise and end the thread.
            main_redis_client.hset(task_key, mapping={
                "status": "failed",
                "error": str(e)
            })

        print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
task_thread.join() + redis_thread.join() \ No newline at end of file diff --git a/api_old/before/mediapipe_api.py b/api_old/before/mediapipe_api.py new file mode 100644 index 0000000..1b75b6c --- /dev/null +++ b/api_old/before/mediapipe_api.py @@ -0,0 +1,304 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import uuid +import os +from datetime import timedelta +import threading +import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision + +app = FastAPI() +mediapipe_app = FastAPI() +app.mount("/mediapipe", mediapipe_app) + +# CORS configuration +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Configuration +MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "mediapipe" +KAFKA_GROUP_ID = "mediapipe_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 10 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# Ensure directories exist +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# Initialize Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# Initialize Redis +redis_client = Redis( + 
host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +class mediapipeEmbedder: + def __init__(self, model_path): + base_options = python.BaseOptions(model_asset_path=model_path) + options = vision.FaceLandmarkerOptions( + base_options=base_options, + output_face_blendshapes=True, + output_facial_transformation_matrixes=True, + num_faces=1 + ) + self.detector = vision.FaceLandmarker.create_from_options(options) + + def get_mediapipe_landmarks(self, image): + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image) + detection_result = self.detector.detect(mp_image) + if detection_result.face_landmarks: + return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]]) + return None + + def process_image(self, image_data): + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + landmarks = self.get_mediapipe_landmarks(img) + + if landmarks is not None: + # Calculate a more detailed mediapipe embedding + embedding = self.calculate_detailed_embedding(landmarks) + + # Draw landmarks on the image + for lm in landmarks: + cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1) + + return { + "embedding": embedding, + "landmarks": landmarks.tolist() + }, img + else: + return None, img + + def calculate_detailed_embedding(self, landmarks): + # Calculate various statistical features + mean = np.mean(landmarks, axis=0) + std = np.std(landmarks, axis=0) + median = np.median(landmarks, axis=0) + min_vals = np.min(landmarks, axis=0) + max_vals = np.max(landmarks, axis=0) + + # Calculate pairwise distances between key facial landmarks + nose_tip = landmarks[4] + left_eye = landmarks[159] + right_eye = landmarks[386] + left_mouth = landmarks[61] + right_mouth = landmarks[291] + + eye_distance = np.linalg.norm(left_eye - right_eye) + mouth_width = np.linalg.norm(left_mouth - right_mouth) + nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2) + + # 
def process_video(video_data, filename):
    """Extract face landmarks from a video, sampling about one frame per second.

    The bytes are written to ``UPLOAD_DIR/mediapipe_<filename>`` so that
    ``cv2.VideoCapture`` can open them; every frame is copied to
    ``RESULT_DIR/mediapipe_<filename>`` (``mp4v``), with landmarks drawn on
    the sampled frames.

    Returns:
        ``(results, annotated_filename)`` where ``results`` holds one
        ``{"frame": index, "results": {...}}`` entry per sampled frame with a
        face, or ``(None, None)`` on any error (the error is printed).
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
    cap = None
    out = None
    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        frame_count = 0
        results = []

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # A reported fps below 1 truncates to 0 and would make the modulo
        # below raise ZeroDivisionError; sample every frame in that case.
        sample_every = max(1, fps)

        annotated_filename = f"mediapipe_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_every == 0:
                # The embedder takes encoded bytes, so the frame is re-encoded.
                encoded = cv2.imencode('.jpg', frame)[1].tobytes()
                frame_results, annotated_frame = embedder.process_image(encoded)
                if frame_results:
                    results.append({"frame": frame_count, "results": frame_results})
            else:
                annotated_frame = frame

            out.write(annotated_frame)
            frame_count += 1

        return results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Release handles and remove the temporary input on the error path too.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
def process_task():
    """Consume upload tasks from Kafka and publish their outcome to Redis.

    Each message is expected to carry ``filename`` and ``file_type``. The
    uploaded file is read from UPLOAD_DIR, handed to ``process_image`` or
    ``process_video``, and the key ``mediapipe_result:<filename>`` is updated
    to ``processing`` and then ``completed`` or ``failed``.
    """
    for message in consumer:
        task = message.value
        try:
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # A malformed message must not terminate the worker thread.
            print(f"Error processing task: {str(e)}")
            continue

        file_path = os.path.join(UPLOAD_DIR, filename)

        redis_key = f"mediapipe_result:{filename}"
        redis_client.set(redis_key, json.dumps({"status": "processing"}))

        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            # Anything that is not an image is treated as a video.
            handler = process_image if file_type == "image" else process_video
            results, annotated_filename = handler(content, filename)

            if results and annotated_filename:
                redis_client.set(redis_key, json.dumps({
                    "results": results,
                    "status": "completed",
                    "annotated_filename": annotated_filename
                }))
            else:
                redis_client.set(redis_key, json.dumps({"status": "failed"}))
        except Exception as e:
            print(f"Error processing task: {str(e)}")
            redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
mode 100644 index 0000000..8b2c104 --- /dev/null +++ b/api_old/before/mediapipe_key.py @@ -0,0 +1,453 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import APIKeyHeader +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import uuid +import os +from datetime import datetime, timedelta, timezone +import threading +import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision + +app = FastAPI() +mediapipe_app = FastAPI() +app.mount("/mediapipe", mediapipe_app) + +# CORS configuration +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Configuration +MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "mediapipe" +KAFKA_GROUP_ID = "mediapipe_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 10 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# Ensure directories exist +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# Initialize Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# Initialize Redis +redis_client = Redis( + host=REDIS_HOST, + 
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Estimate how many tokens processing an uploaded file costs.

    The cost is 10 tokens per MiB of file size, plus:

    * image: 5 tokens per 10,000 pixels;
    * video: ``(width * height * frame_count / 10000) * (duration_seconds / 60)``.

    Args:
        file_path: Path of the stored upload.
        file_type: ``"image"`` or ``"video"``; any other value is charged by
            file size only.

    Returns:
        The token cost, never less than 1. If the file cannot be inspected
        the error is printed and 1 is returned.
    """
    base_tokens = 0

    try:
        file_size = os.path.getsize(file_path)  # size in bytes

        # Base cost: 10 tokens per MiB.
        base_tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            img = cv2.imread(file_path)
            if img is None:
                raise ValueError("无法读取图片文件")
            height, width = img.shape[:2]
            pixel_count = height * width

            # Image surcharge: 5 tokens per 10,000 pixels.
            base_tokens += int((pixel_count / 10000) * 5)

        elif file_type == "video":
            cap = cv2.VideoCapture(file_path)
            try:
                if not cap.isOpened():
                    raise ValueError("无法打开视频文件")
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # Release the capture even when the file could not be opened.
                cap.release()

            pixel_count = width * height * frame_count
            # A container that reports no frame rate would otherwise raise
            # ZeroDivisionError and throw away the size-based cost as well.
            duration = frame_count / fps if fps > 0 else 0.0  # seconds

            # Video surcharge: total pixels / 10,000, scaled by duration in minutes.
            base_tokens += int((pixel_count / 10000) * (duration / 60))

        return max(1, base_tokens)  # always charge at least one token
    except Exception as e:
        print(f"计算token时出错: {str(e)}")
        return 1  # default cost when the file cannot be inspected
class mediapipeEmbedder:
    """Wraps a MediaPipe FaceLandmarker and derives a small numeric embedding from its landmarks."""

    def __init__(self, model_path):
        """Create a single-face landmarker from the ``.task`` model at ``model_path``."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Return the first face's landmarks as an ``(N, 3)`` array of (x, y, z), or None.

        ``image`` is wrapped as ``ImageFormat.SRGB``, so it should be an RGB
        uint8 array.
        """
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
        return None

    def process_image(self, image_data):
        """Detect landmarks in encoded image bytes and draw them on the image.

        Returns:
            ``({"embedding": [...], "landmarks": [...]}, annotated_bgr_image)``
            when a face is found, otherwise ``(None, image)``. The image is
            ``None`` when the bytes cannot be decoded.
        """
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # Undecodable bytes: report "no result" rather than failing inside MediaPipe.
            return None, None

        # OpenCV decodes to BGR while the detector input is declared as SRGB,
        # so convert for detection and keep the BGR copy for drawing/saving.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        landmarks = self.get_mediapipe_landmarks(rgb)

        if landmarks is not None:
            embedding = self.calculate_detailed_embedding(landmarks)

            # Landmark x/y are scaled by the image width/height to get pixel positions.
            for lm in landmarks:
                cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)

            return {
                "embedding": embedding,
                "landmarks": landmarks.tolist()
            }, img
        else:
            return None, img

    def calculate_detailed_embedding(self, landmarks):
        """Summarise an ``(N, 3)`` landmark array as a flat list of 21 floats.

        Layout: per-axis mean, std, median, min and max (15 values), followed
        by eye distance, mouth width, nose-to-mouth distance, and the face
        width, height and depth. Requires at least 387 landmarks because
        indices up to 386 are read.
        """
        # Per-axis statistics over all landmarks.
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Key points, addressed by landmark index.
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]

        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Extent of the landmark cloud along each axis (equals max - min per axis).
        face_width, face_height, face_depth = max_vals - min_vals

        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
        ])

        return embedding.tolist()
@mediapipe_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Store an upload, charge the API key for it and queue it for processing.

    Raises:
        HTTPException: 403 when the key is invalid or its token balance is
            insufficient; 500 when no token cost could be computed.

    Returns:
        A JSON response with the generated filename and the key's token usage.
    """
    # Validate the API key.
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    content = await file.read()
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"

    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    def discard_upload():
        # A rejected request must not leave its file behind on disk.
        try:
            os.remove(original_file_path)
        except OSError:
            pass

    # Compute the token cost; anything that is not a known image type is treated as video.
    file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
    tokens_required = calculate_tokens(original_file_path, file_type)
    if tokens_required is None or tokens_required <= 0:
        discard_upload()
        raise HTTPException(status_code=500, detail="无法计算所需的token数量")

    # Check the balance before charging.
    # NOTE(review): check-then-increment is not atomic; concurrent uploads with
    # the same key can overspend. Confirm whether that is acceptable.
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        discard_upload()
        raise HTTPException(status_code=403, detail="Token 余额不足")

    model_name = "mediapipe"
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish the initial status BEFORE handing the task to Kafka; otherwise a
    # fast worker's "processing"/"completed" status could be overwritten by "queued".
    redis_key = f"mediapipe_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": file_type
    }).encode('utf-8'))

    # Re-read usage for the response. verify_api_key returns None if the key
    # became invalid in the meantime, so fall back to the locally known total.
    updated_api_key_info = await verify_api_key(api_key) or {}
    new_tokens_used = int(updated_api_key_info.get("tokens_used", tokens_used + tokens_required))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": new_filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
def listen_redis_changes():
    """Print a status line whenever a ``mediapipe_result:*`` key is set.

    Subscribes to Redis keyspace notifications for the database this service
    writes to (``REDIS_DB``). The previous hard-coded ``@3`` index did not
    match this file's ``REDIS_DB``, so no event was ever delivered.
    NOTE(review): the server must also have keyspace notifications enabled
    (``notify-keyspace-events``) - confirm in the deployment.
    """
    pubsub = redis_client.pubsub()
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:mediapipe_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # The channel looks like "__keyspace@<db>__:mediapipe_result:<filename>".
        key = message['channel'].decode('utf-8').split(':')[-1]
        operation = message['data'].decode('utf-8')

        if operation == 'set':
            value = redis_client.get(f"mediapipe_result:{key}")
            if value:
                result = json.loads(value)
                print(f"Status update for {key}: {result['status']}")
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the API key from an ``Authorization: Bearer obs-...`` header.

    Returns:
        The key (the part after ``"Bearer "``) when it starts with ``"obs-"``.

    Raises:
        HTTPException: 401 when the header is absent, is not a Bearer
            credential, or the key lacks the ``obs-`` prefix.
    """
    rejection = HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if not api_key or not api_key.startswith("Bearer "):
        raise rejection

    token = api_key.split(" ")[1]
    if not token.startswith("obs-"):
        raise rejection

    return token
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate an API key against Redis and return its metadata and usage.

    Reads the hash ``api_key:<key>`` from the key database and the same key
    from the usage database.

    Returns:
        A dict with ``"api_key"`` plus every decoded field of both hashes.

    Raises:
        HTTPException: 401 when the key is unknown, inactive, expired, or its
            ``expires_at`` field is missing or unparsable.
    """
    redis_key = f"api_key:{api_key}"

    api_key_info = redis_api_client.hgetall(redis_key)

    if not api_key_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")

    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}

    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    try:
        expires_at = datetime.fromisoformat(api_key_info['expires_at'])
    except (KeyError, ValueError):
        # A record without a usable expiry is rejected instead of surfacing as a 500.
        raise HTTPException(status_code=401, detail="无效的API密钥") from None

    if expires_at.tzinfo is None:
        # Comparing naive and aware datetimes raises TypeError.
        # NOTE(review): assumes naive timestamps were stored in UTC - confirm with the writer.
        expires_at = expires_at.replace(tzinfo=timezone.utc)

    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info
    }
output_facial_transformation_matrixes=True, + num_faces=1 + ) + self.detector = vision.FaceLandmarker.create_from_options(options) + + def get_mediapipe_landmarks(self, image): + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image) + detection_result = self.detector.detect(mp_image) + if detection_result.face_landmarks: + return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]]) + return None + + def process_image(self, image_data): + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + landmarks = self.get_mediapipe_landmarks(img) + + if landmarks is not None: + # Calculate a more detailed mediapipe embedding + embedding = self.calculate_detailed_embedding(landmarks) + + # Draw landmarks on the image + for lm in landmarks: + cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1) + + return { + "embedding": embedding, + "landmarks": landmarks.tolist() + }, img + else: + return None, img + + def calculate_detailed_embedding(self, landmarks): + # Calculate various statistical features + mean = np.mean(landmarks, axis=0) + std = np.std(landmarks, axis=0) + median = np.median(landmarks, axis=0) + min_vals = np.min(landmarks, axis=0) + max_vals = np.max(landmarks, axis=0) + + # Calculate pairwise distances between key facial landmarks + nose_tip = landmarks[4] + left_eye = landmarks[159] + right_eye = landmarks[386] + left_mouth = landmarks[61] + right_mouth = landmarks[291] + + eye_distance = np.linalg.norm(left_eye - right_eye) + mouth_width = np.linalg.norm(left_mouth - right_mouth) + nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2) + + # Calculate face shape features + face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0]) + face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1]) + face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2]) + + # Combine all features into a single embedding + embedding = np.concatenate([ + mean, 
def process_image(image_data, filename):
    """Run the embedder on encoded image bytes and save the annotated image.

    Returns:
        ``(results, annotated_filename)`` when landmarks were found; the
        annotated image is written to RESULT_DIR as ``mediapipe_<filename>``.
        ``(None, None)`` when no face is detected or an error occurs.
    """
    try:
        detection, drawn = embedder.process_image(image_data)

        if not detection:
            print(f"No face landmarks detected in image: {filename}")
            return None, None

        # Persist the annotated image next to the other results.
        output_name = f"mediapipe_{filename}"
        cv2.imwrite(os.path.join(RESULT_DIR, output_name), drawn)
        return detection, output_name
    except Exception as e:
        print(f"Error processing image {filename}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
@mediapipe_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Store an upload, charge the authenticated API key and queue the file for processing.

    Raises:
        HTTPException: 403 when the token balance is insufficient; 500 when
            no token cost could be computed.

    Returns:
        A JSON response with the generated filename and the key's token usage.
    """
    content = await file.read()
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"

    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    def discard_upload():
        # A rejected request must not leave its file behind on disk.
        try:
            os.remove(original_file_path)
        except OSError:
            pass

    # Compute the token cost; anything that is not a known image type is treated as video.
    file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
    tokens_required = calculate_tokens(original_file_path, file_type)
    if tokens_required is None or tokens_required <= 0:
        discard_upload()
        raise HTTPException(status_code=500, detail="无法计算所需的token数量")

    # Check the balance before charging.
    # NOTE(review): check-then-increment is not atomic; concurrent uploads with
    # the same key can overspend. Confirm whether that is acceptable.
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        discard_upload()
        raise HTTPException(status_code=403, detail="Token 余额不足")

    model_name = "mediapipe"
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish the initial status BEFORE handing the task to Kafka; otherwise a
    # fast worker's "processing"/"completed" status could be overwritten by "queued".
    redis_key = f"mediapipe_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": file_type
    }).encode('utf-8'))

    # Read the updated counters directly. Re-running verify_api_key here could
    # raise 401 after the job has already been charged and queued.
    new_tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    model_tokens_used = int(redis_api_usage_client.hget(usage_key, f"{model_name}_tokens_used") or 0)

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": new_filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
@mediapipe_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
async def get_annotated_file(filename: str):
    """Stream the annotated image or video produced for an upload.

    Looks up ``mediapipe_result:<filename>`` in Redis and, when the job is
    completed and its output exists under RESULT_DIR, streams that file.

    Raises:
        HTTPException: 404 when there is no completed result or the file is missing.
    """
    redis_key = f"mediapipe_result:{filename}"
    result = redis_client.get(redis_key)
    if result:
        result_data = json.loads(result)
        annotated_filename = result_data.get("annotated_filename")
        # .get() keeps an incomplete record a 404 instead of a KeyError (500).
        if result_data.get("status") == "completed" and annotated_filename:
            file_path = os.path.join(RESULT_DIR, annotated_filename)
            if os.path.exists(file_path):
                def iterfile():
                    with open(file_path, mode="rb") as file_like:
                        yield from file_like

                # "image/jpg" is not a registered media type; JPEG files are "image/jpeg".
                image_media_types = {
                    '.jpg': 'image/jpeg',
                    '.jpeg': 'image/jpeg',
                    '.png': 'image/png',
                }
                file_extension = os.path.splitext(annotated_filename)[1].lower()
                media_type = image_media_types.get(file_extension, "video/mp4")
                return StreamingResponse(iterfile(), media_type=media_type)

    raise HTTPException(status_code=404, detail="Annotated file not found")
def listen_redis_changes():
    """Print a status line whenever a ``mediapipe_result:*`` key is set.

    Subscribes to Redis keyspace notifications for the database this service
    writes to (``REDIS_DB``). The previous hard-coded ``@3`` index did not
    match this file's ``REDIS_DB``, so no event was ever delivered.
    NOTE(review): the server must also have keyspace notifications enabled
    (``notify-keyspace-events``) - confirm in the deployment.
    """
    pubsub = redis_client.pubsub()
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:mediapipe_result:*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        # The channel looks like "__keyspace@<db>__:mediapipe_result:<filename>".
        key = message['channel'].decode('utf-8').split(':')[-1]
        operation = message['data'].decode('utf-8')

        if operation == 'set':
            value = redis_client.get(f"mediapipe_result:{key}")
            if value:
                result = json.loads(value)
                print(f"Status update for {key}: {result['status']}")
class PoseDetector:
    """Thin wrapper around a YOLO pose model pinned to ``cuda:1``."""

    def __init__(self, model_path):
        """Load the model weights from ``model_path`` onto ``cuda:1``."""
        self.model = YOLO(model_path).to('cuda:1')

    def detect(self, frame):
        """Run the model on ``frame`` and return its results."""
        return self.model(frame, device='cuda:1')

    def format_results(self, results):
        """Flatten model results into JSON-serialisable dicts.

        Returns:
            One ``{"bbox", "confidence", "keypoints"}`` dict per detected
            box, in result order; an empty list when nothing was detected.
        """
        formatted_results = []
        for r in results:
            boxes = r.boxes
            keypoints = r.keypoints
            for i in range(len(boxes)):
                box = boxes[i]
                kpts = keypoints[i]
                formatted_results.append({
                    "bbox": box.xyxy.tolist()[0],
                    "confidence": box.conf.item(),
                    "keypoints": kpts.xy.tolist()[0]
                })
        return formatted_results

    def draw_results(self, frame, results, original_shape):
        """Draw detections on ``frame`` and resize the image to ``original_shape``.

        Args:
            frame: Image passed to each result's ``plot``.
            results: Iterable of model results; the last one's plot is returned.
            original_shape: ``(height, width, ...)`` of the source image.

        Returns:
            The annotated image resized to the original width and height.
            With no results, the input frame is resized and returned
            unannotated (previously this raised UnboundLocalError).
        """
        annotated_frame = frame
        for r in results:
            annotated_frame = r.plot(img=frame)
        # Scale back to the source image size; cv2.resize takes (width, height).
        return cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
cv2.COLOR_BGR2RGB) + + # Resize image to fit model requirements (640x640) + img_resized = cv2.resize(img, (640, 640)) + + # Normalize and reshape to BCHW format + img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + img_tensor = img_tensor.to('cuda:1') + + results = detector.detect(img_tensor) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on original image + annotated_img = detector.draw_results(img_resized, results, original_shape) + + # Save annotated image + annotated_filename = f"pose_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR)) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_shape = (height, width) + + # Create output video file + annotated_filename = f"pose_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + # Process one frame per second + if frame_count % fps == 0: + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize frame to fit model requirements (640x640) + frame_resized = cv2.resize(frame_rgb, (640, 640)) + + # Normalize and reshape to BCHW format + frame_tensor = 
torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + frame_tensor = frame_tensor.to('cuda:1') + + results = detector.detect(frame_tensor) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + # Draw results on original frame + annotated_frame = detector.draw_results(frame_resized, results, original_shape) + # Convert RGB back to BGR for OpenCV + annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) + else: + annotated_frame = frame + + out.write(annotated_frame) + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video file + os.remove(temp_video_path) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing video: {str(e)}") + return None, None + +@pose_app.post("/upload") +async def upload_file(file: UploadFile = File(...)): + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # Save the original file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"pose_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename}) + +@pose_app.get("/result/{filename}") +async def get_pose_result(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + return JSONResponse(content=result_data) # 直接返回整个结果,包括 status + else: + raise 
HTTPException(status_code=404, detail="Result not found") + +@pose_app.get("/annotated/{filename}") +async def get_annotated_file(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + if result_data["status"] == "completed": + annotated_filename = result_data["annotated_filename"] + file_path = os.path.join(RESULT_DIR, annotated_filename) + if os.path.exists(file_path): + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + file_extension = os.path.splitext(annotated_filename)[1].lower() + return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4") + + raise HTTPException(status_code=404, detail="Annotated file not found") + +def process_task(): + for message in consumer: + task = message.value + filename = task['filename'] + file_type = task['file_type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + # Update status to "processing" + redis_key = f"pose_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + try: + if file_type == "image": + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_image(content, filename) + else: + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_video(content, filename) + + if json_results and annotated_filename: + redis_client.set(redis_key, json.dumps({ + "json_results": json_results, + "status": "completed", + "annotated_filename": annotated_filename + })) + else: + redis_client.set(redis_key, json.dumps({"status": "failed"})) + except Exception as e: + print(f"Error processing task: {str(e)}") + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:pose_result:*') # 
监听所有pose_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"pose_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7001) \ No newline at end of file diff --git a/api_old/before/pose_key.py b/api_old/before/pose_key.py new file mode 100644 index 0000000..4d898c2 --- /dev/null +++ b/api_old/before/pose_key.py @@ -0,0 +1,461 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import APIKeyHeader +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import uuid +import os +from datetime import datetime, timedelta, timezone +import threading +import torch +torch.cuda.set_device(1) + + +app = FastAPI() +pose_app = FastAPI() +app.mount("/pose", pose_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/yolov8x-pose.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "pose" # 指定Kafka topic +KAFKA_GROUP_ID = "pose_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 3 +REDIS_API_DB = 12 
+REDIS_API_USAGE_DB = 13 +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 添加API密钥验证 +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + +async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)): + if api_key is None: + raise HTTPException(status_code=400, detail="API密钥缺失") + return api_key + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = cv2.imread(file_path) + if img is None: + raise ValueError("无法读取图片文件") + height, width = img.shape[:2] + pixel_count = height * width + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + cap = cv2.VideoCapture(file_path) + if not cap.isOpened(): + raise ValueError("无法打开视频文件") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = 
int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + return None + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + return None + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + return None + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +class PoseDetector: + def __init__(self, model_path): + self.model = YOLO(model_path).to('cuda:1') + + def detect(self, frame): + results = self.model(frame, device='cuda:1') + return results + + def format_results(self, results): + 
formatted_results = [] + for r in results: + boxes = r.boxes + keypoints = r.keypoints + for i in range(len(boxes)): + box = boxes[i] + kpts = keypoints[i] + formatted_results.append({ + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "keypoints": kpts.xy.tolist()[0] + }) + return formatted_results + + def draw_results(self, frame, results, original_shape): + for r in results: + annotated_frame = r.plot(img=frame) + # 调整坐标以适应原始图像大小 + h, w = annotated_frame.shape[:2] + scale_x, scale_y = original_shape[1] / w, original_shape[0] / h + annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0])) + return annotated_frame + +detector = PoseDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + original_shape = img.shape + # Convert BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # Resize image to fit model requirements (640x640) + img_resized = cv2.resize(img, (640, 640)) + + # Normalize and reshape to BCHW format + img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + img_tensor = img_tensor.to('cuda:1') + + results = detector.detect(img_tensor) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on original image + annotated_img = detector.draw_results(img_resized, results, original_shape) + + # Save annotated image + annotated_filename = f"pose_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR)) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = 
cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_shape = (height, width) + + # Create output video file + annotated_filename = f"pose_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + # Process one frame per second + if frame_count % fps == 0: + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize frame to fit model requirements (640x640) + frame_resized = cv2.resize(frame_rgb, (640, 640)) + + # Normalize and reshape to BCHW format + frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + frame_tensor = frame_tensor.to('cuda:1') + + results = detector.detect(frame_tensor) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + # Draw results on original frame + annotated_frame = detector.draw_results(frame_resized, results, original_shape) + # Convert RGB back to BGR for OpenCV + annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) + else: + annotated_frame = frame + + out.write(annotated_frame) + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video file + os.remove(temp_video_path) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing video: {str(e)}") + return None, None + +@pose_app.post("/upload") +async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)): + # 验证 API key + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + 
content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # Save the original file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(original_file_path, file_type) + if tokens_required is None or tokens_required <= 0: + raise HTTPException(status_code=500, detail="无法计算所需的token数量") + + # 检查并更新 token 使用量 + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + model_name = "yolov8x-pose" + await update_token_usage(api_key, tokens_required, model_name) + + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"pose_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": new_filename, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + + +@pose_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)]) +async def 
get_pose_result(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + return JSONResponse(content=result_data) # 直接返回整个结果,包括 status + else: + raise HTTPException(status_code=404, detail="Result not found") + +@pose_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_annotated_file(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + if result_data["status"] == "completed": + annotated_filename = result_data["annotated_filename"] + file_path = os.path.join(RESULT_DIR, annotated_filename) + if os.path.exists(file_path): + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + file_extension = os.path.splitext(annotated_filename)[1].lower() + return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4") + + raise HTTPException(status_code=404, detail="Annotated file not found") + +def process_task(): + for message in consumer: + task = message.value + filename = task['filename'] + file_type = task['file_type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + # Update status to "processing" + redis_key = f"pose_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + try: + if file_type == "image": + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_image(content, filename) + else: + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_video(content, filename) + + if json_results and annotated_filename: + redis_client.set(redis_key, json.dumps({ + "json_results": json_results, + "status": "completed", + "annotated_filename": annotated_filename + })) + else: + redis_client.set(redis_key, json.dumps({"status": 
"failed"})) + except Exception as e: + print(f"Error processing task: {str(e)}") + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"pose_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7001) \ No newline at end of file diff --git a/api_old/before/pose_key2.py b/api_old/before/pose_key2.py new file mode 100644 index 0000000..6f9e599 --- /dev/null +++ b/api_old/before/pose_key2.py @@ -0,0 +1,470 @@ +from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.security import APIKeyHeader +from ultralytics import YOLO +import cv2 +import numpy as np +import json +import uvicorn +from kafka import KafkaProducer, KafkaConsumer +from redis import Redis +import uuid +import os +from datetime import datetime, timedelta, timezone +import threading +import torch +import string + +torch.cuda.set_device(1) + +app = FastAPI() +pose_app = FastAPI() +app.mount("/pose", pose_app) + +# CORS配置 +ALLOWED_ORIGINS = 'https://beta.obscura.work' + +# 只为主应用添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = 
"/home/zydi/models/yolov8x-pose.pt" # 请替换为您的模型路径 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "pose" # 指定Kafka topic +KAFKA_GROUP_ID = "pose_group" # 指定消费者组ID + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 3 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = cv2.imread(file_path) + if img is 
None: + raise ValueError("无法读取图片文件") + height, width = img.shape[:2] + pixel_count = height * width + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + cap = cv2.VideoCapture(file_path) + if not cap.isOpened(): + raise ValueError("无法打开视频文件") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cap.release() + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", 
new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +class PoseDetector: + def __init__(self, model_path): + self.model = YOLO(model_path).to('cuda:1') + + def detect(self, frame): + results = self.model(frame, device='cuda:1') + return results + + def format_results(self, results): + formatted_results = [] + for r in results: + boxes = r.boxes + keypoints = r.keypoints + for i in range(len(boxes)): + box = boxes[i] + kpts = keypoints[i] + formatted_results.append({ + "bbox": box.xyxy.tolist()[0], + "confidence": box.conf.item(), + "keypoints": kpts.xy.tolist()[0] + }) + return formatted_results + + def draw_results(self, frame, results, original_shape): + for r in results: + annotated_frame = r.plot(img=frame) + # 调整坐标以适应原始图像大小 + h, w = annotated_frame.shape[:2] + scale_x, scale_y = original_shape[1] / w, original_shape[0] / h + annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0])) + return annotated_frame + +detector = PoseDetector(MODEL_PATH) + +def process_image(image_data, filename): + try: + img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR) + original_shape = img.shape + # Convert BGR to RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # Resize image to fit model requirements (640x640) + img_resized = cv2.resize(img, (640, 640)) + + # Normalize and reshape to BCHW format + img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + img_tensor = img_tensor.to('cuda:1') + + results = detector.detect(img_tensor) + + # Format results for JSON + json_results = detector.format_results(results) + + # Draw results on original image + annotated_img = detector.draw_results(img_resized, 
results, original_shape) + + # Save annotated image + annotated_filename = f"pose_{filename}" + annotated_path = os.path.join(RESULT_DIR, annotated_filename) + cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR)) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing image: {str(e)}") + return None, None + +def process_video(video_data, filename): + try: + temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}") + with open(temp_video_path, 'wb') as temp_video: + temp_video.write(video_data) + + cap = cv2.VideoCapture(temp_video_path) + frame_count = 0 + json_results = [] + + # Get video properties + fps = int(cap.get(cv2.CAP_PROP_FPS)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_shape = (height, width) + + # Create output video file + annotated_filename = f"pose_{filename}" + output_path = os.path.join(RESULT_DIR, annotated_filename) + out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + # Process one frame per second + if frame_count % fps == 0: + # Convert BGR to RGB + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Resize frame to fit model requirements (640x640) + frame_resized = cv2.resize(frame_rgb, (640, 640)) + + # Normalize and reshape to BCHW format + frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0 + frame_tensor = frame_tensor.to('cuda:1') + + results = detector.detect(frame_tensor) + frame_json_results = detector.format_results(results) + json_results.append({"frame": frame_count, "detections": frame_json_results}) + + # Draw results on original frame + annotated_frame = detector.draw_results(frame_resized, results, original_shape) + # Convert RGB back to BGR for OpenCV + annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) + else: + 
annotated_frame = frame + + out.write(annotated_frame) + frame_count += 1 + + cap.release() + out.release() + + # Clean up temporary input video file + os.remove(temp_video_path) + + return json_results, annotated_filename + except Exception as e: + print(f"Error processing video: {str(e)}") + return None, None + +@pose_app.post("/upload") +async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # Save the original file + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(original_file_path, file_type) + if tokens_required is None or tokens_required <= 0: + raise HTTPException(status_code=500, detail="无法计算所需的token数量") + + # 检查并更新 token 使用量 + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + model_name = "yolov8x-pose" + await update_token_usage(api_key, tokens_required, model_name) + + + # Send processing task to Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + }).encode('utf-8')) + + # Set initial status in Redis + redis_key = f"pose_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = 
int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": new_filename, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + + +@pose_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_pose_result(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + return JSONResponse(content=result_data) # 直接返回整个结果,包括 status + else: + raise HTTPException(status_code=404, detail="Result not found") + +@pose_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)]) +async def get_annotated_file(filename: str): + redis_key = f"pose_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_data = json.loads(result) + if result_data["status"] == "completed": + annotated_filename = result_data["annotated_filename"] + file_path = os.path.join(RESULT_DIR, annotated_filename) + if os.path.exists(file_path): + def iterfile(): + with open(file_path, mode="rb") as file_like: + yield from file_like + file_extension = os.path.splitext(annotated_filename)[1].lower() + return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4") + + raise HTTPException(status_code=404, detail="Annotated file not found") + +def process_task(): + for message in consumer: + task = message.value + filename = task['filename'] + file_type = task['file_type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + # Update status to "processing" + redis_key = f"pose_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + try: + if file_type == "image": + with 
open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_image(content, filename) + else: + with open(file_path, 'rb') as f: + content = f.read() + json_results, annotated_filename = process_video(content, filename) + + if json_results and annotated_filename: + redis_client.set(redis_key, json.dumps({ + "json_results": json_results, + "status": "completed", + "annotated_filename": annotated_filename + })) + else: + redis_client.set(redis_key, json.dumps({"status": "failed"})) + except Exception as e: + print(f"Error processing task: {str(e)}") + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + +def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化 + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + operation = message['data'].decode('utf-8') + + if operation == 'set': + value = redis_client.get(f"pose_result:{key}") + if value: + result = json.loads(value) + print(f"Status update for {key}: {result['status']}") + + # 这里可以添加其他处理逻辑,比如发送通知等 + +if __name__ == "__main__": + # 启动处理任务的线程 + threading.Thread(target=process_task, daemon=True).start() + + # 启动Redis监听线程 + threading.Thread(target=listen_redis_changes, daemon=True).start() + + uvicorn.run(app, host="0.0.0.0", port=7001) \ No newline at end of file diff --git a/api_old/before/qwenvl_api.py b/api_old/before/qwenvl_api.py new file mode 100644 index 0000000..5cf7a90 --- /dev/null +++ b/api_old/before/qwenvl_api.py @@ -0,0 +1,356 @@ +import os +import json +import uuid +from datetime import datetime, timedelta +from fastapi import FastAPI, HTTPException, UploadFile, File +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from kafka import KafkaProducer, KafkaConsumer +from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, 
AutoProcessor +from qwen_vl_utils import process_vision_info +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis +import io +import re +import torch +from contextlib import asynccontextmanager +import threading + +app = FastAPI() +qwenvl_app = FastAPI() +app.mount("/qwenvl", qwenvl_app) +torch.cuda.set_device(1) +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "qwenvl" +KAFKA_GROUP_ID = "qwenvl_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 8 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + + +# 初始化模型 +model = Qwen2VLForConditionalGeneration.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cuda:1" +) + +min_pixels = 128*28*28 +max_pixels = 512*28*28 +processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels) + +class MediaAnalysisSystem: + def __init__(self, model, processor): + self.model = model + self.processor = processor + self.MAX_NUM_FRAMES = 10 + + def encode_video(self, video_data): + def uniform_sample(l, n): + 
gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file) + sample_fps = round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_media(self, media_data, object_name, media_type='image'): + if not media_data: + raise ValueError(f"Empty {media_type} data for {object_name}") + + print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes") + + if media_type == 'video': + frames = self.encode_video(media_data) + media_content = {"type": "video", "video": frames, "fps": 1.0} + else: # image + image = Image.open(io.BytesIO(media_data)) + media_content = {"type": "image", "image": image} + + messages = [ + { + "role": "user", + "content": [ + media_content, + {"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"}, + ], + } + ] + + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_inputs, video_inputs = process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to('cuda:1') + generated_ids = self.model.generate(**inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + answer = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + extracted_info = self.extract_info(answer) + + result = { + "original_answer": answer, + "extracted_info": extracted_info, + "timestamp": 
datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + if media_type == 'video': + result["num_frames"] = len(frames) + + return result + + def process_video(self, video_data, object_name): + return self.process_media(video_data, object_name, media_type='video') + + def process_image(self, image_data, object_name): + return self.process_media(image_data, object_name, media_type='image') + + @staticmethod + def extract_time_from_filename(object_name): + filename = os.path.basename(object_name) + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", 
"跳", "躺", "睡", "说话"] + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]: + for item in item_list: + if item in answer: + info[key].append(item) + + return info + +# 初始化 MediaAnalysisSystem +media_analysis_system = MediaAnalysisSystem(model, processor) + +async def process_file(file: UploadFile, file_type: str): + content = await file.read() + # 获取原始文件的后缀 + original_extension = os.path.splitext(file.filename)[1] + + # 生成新的文件名,包含 UUID 和原始后缀 + filename = f"qwenvl_{uuid.uuid4()}{original_extension}" + file_path = os.path.join(UPLOAD_DIR, filename) + with open(file_path, "wb") as f: + f.write(content) + + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": filename, + "type": file_type + }).encode('utf-8')) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename} + +@qwenvl_app.post("/upload") +async def upload_file(file: UploadFile = File(...)): + try: + file_type = "image" if file.content_type.startswith("image") else "video" + return await process_file(file, file_type) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@qwenvl_app.post("/analyze_video") +async def analyze_video(file: UploadFile = File(...)): + try: + return await process_file(file, "video") + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@qwenvl_app.post("/analyze_image") +async def analyze_image(file: UploadFile = File(...)): + try: + return await process_file(file, "image") + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +def process_task(): + for message in consumer: + try: + if 
isinstance(message.value, dict): + task = message.value + else: + task = json.loads(message.value.decode('utf-8')) + + filename = task['filename'] + file_type = task['type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + with open(file_path, 'rb') as f: + file_data = f.read() + + if file_type == "video": + result = media_analysis_system.process_video(file_data, filename) + elif file_type == "image": + result = media_analysis_system.process_image(file_data, filename) + + # 保存结果到 JSON 文件 + result_file_path = os.path.join(RESULT_DIR, f"{filename}.json") + with open(result_file_path, 'w', encoding='utf-8') as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + # 将结果存储在 Redis 中 + redis_client.set(redis_key, json.dumps({ + "status": "completed", + "result": result + })) + + except Exception as e: + print(f"Error processing task: {str(e)}") + if 'filename' in locals() and 'file_type' in locals(): + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + else: + print("Error occurred before task details were extracted") + +@qwenvl_app.get("/result/{filename}") +async def get_result(filename: str): + for file_type in ["video", "image"]: + redis_key = f"{file_type}_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_json = json.loads(result) + + if result_json.get("status") == "queued": + return {"status": "queued", "message": "Your request is in the queue and will be processed soon."} + elif result_json.get("status") == "processing": + return {"status": "processing", "message": "Your request is being processed."} + else: + return result_json + + raise HTTPException(status_code=404, detail="Result not found") + +async def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@5__:*_result:*') + + for message in 
pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + print(f"Key changed: {key}") + +if __name__ == "__main__": + # 在后台线程中启动Kafka消费者 + consumer_thread = threading.Thread(target=process_task, daemon=True) + consumer_thread.start() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=7005) \ No newline at end of file diff --git a/api_old/before/qwenvl_key.py b/api_old/before/qwenvl_key.py new file mode 100644 index 0000000..d730b5e --- /dev/null +++ b/api_old/before/qwenvl_key.py @@ -0,0 +1,508 @@ +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer, KafkaConsumer +from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor +from qwen_vl_utils import process_vision_info +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis +import io +import re +import torch +from contextlib import asynccontextmanager +import threading + +app = FastAPI() +qwenvl_app = FastAPI() +app.mount("/qwenvl", qwenvl_app) +torch.cuda.set_device(1) + +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "qwenvl" +KAFKA_GROUP_ID = "qwenvl_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 8 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" 
+MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 添加API密钥验证 +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + +async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)): + if api_key is None: + raise HTTPException(status_code=400, detail="API密钥缺失") + return api_key + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + return None + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + return None + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + return None + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + 
pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 10) + + if file_type == "image": + img = Image.open(file_path) + width, height = img.size + pixel_count = width * height + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 10000) * 5) + + base_tokens += image_tokens + + elif file_type == "video": + vr = VideoReader(file_path) + fps = vr.get_avg_fps() + frame_count = len(vr) + width, height = vr[0].shape[1], vr[0].shape[0] + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 10000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + +# 初始化模型 +model = Qwen2VLForConditionalGeneration.from_pretrained( + MODEL_PATH, torch_dtype="auto", device_map="cuda:1" +) + +min_pixels = 128*28*28 +max_pixels = 512*28*28 +processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels) + +class MediaAnalysisSystem: + def __init__(self, model, processor): + self.model = model + self.processor = processor + self.MAX_NUM_FRAMES = 10 + + def encode_video(self, video_data): + def uniform_sample(l, n): + gap = len(l) / n + return [l[int(i * gap + gap / 2)] for i in range(n)] + + video_file = io.BytesIO(video_data) + vr = VideoReader(video_file) + sample_fps = 
round(vr.get_avg_fps() / 1) + frame_idx = list(range(0, len(vr), sample_fps)) + if len(frame_idx) > self.MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype('uint8')) for v in frames] + print('num frames:', len(frames)) + return frames + + def process_media(self, media_data, object_name, media_type='image'): + if not media_data: + raise ValueError(f"Empty {media_type} data for {object_name}") + + print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes") + + if media_type == 'video': + frames = self.encode_video(media_data) + media_content = {"type": "video", "video": frames, "fps": 1.0} + else: # image + image = Image.open(io.BytesIO(media_data)) + media_content = {"type": "image", "image": image} + + messages = [ + { + "role": "user", + "content": [ + media_content, + {"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"}, + ], + } + ] + + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_inputs, video_inputs = process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to('cuda:1') + generated_ids = self.model.generate(**inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + answer = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + extracted_info = self.extract_info(answer) + + result = { + "original_answer": answer, + "extracted_info": extracted_info, + "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") + } + + if media_type == 'video': + result["num_frames"] = len(frames) + + return result + + def process_video(self, video_data, 
object_name): + return self.process_media(video_data, object_name, media_type='video') + + def process_image(self, image_data, object_name): + return self.process_media(image_data, object_name, media_type='image') + + @staticmethod + def extract_time_from_filename(object_name): + filename = os.path.basename(object_name) + time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0] + + try: + start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S") + end_time = start_time + timedelta(seconds=10) + return start_time, end_time + except ValueError: + print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。") + return datetime.now(), datetime.now() + timedelta(seconds=10) + + @staticmethod + def extract_info(answer): + info = { + "environment": None, + "num_people": None, + "actions": [], + "interactions": [], + "objects": [], + "furniture": [] + } + + environments = ["办公室", "室内", "室外", "会议室"] + for env in environments: + if env in answer.lower(): + info["environment"] = env + break + + people_patterns = [ + r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)', + r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)', + r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?', + r'(男|女)(性|生|士)', + r'(成年|未成年|青少年|老年)\s*(人|群体)', + r'(员工|职工|工人|学生|顾客|观众|游客|乘客)', + r'(群众|民众|大众|公众)', + r'(男女|老少|老幼|大人|小孩)' + ] + for pattern in people_patterns: + match = re.search(pattern, answer) + if match: + if match.group(1).isdigit(): + info["num_people"] = int(match.group(1)) + elif match.group(1) in ['一个', '一']: + info["num_people"] = 1 + else: + num_word_to_digit = { + '二': 2, '三': 3, '四': 4, '五': 5, + '六': 6, '七': 7, '八': 8, '九': 9, '十': 10 + } + info["num_people"] = num_word_to_digit.get(match.group(1), 0) + break + + actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"] + interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"] + objects = ["水瓶", "办公用品", "文件", "电脑"] + furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"] + + for 
item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]: + for item in item_list: + if item in answer: + info[key].append(item) + + return info + +# 初始化 MediaAnalysisSystem +media_analysis_system = MediaAnalysisSystem(model, processor) + +async def process_file(file: UploadFile, file_type: str, api_key: str): + content = await file.read() + original_extension = os.path.splitext(file.filename)[1] + + filename = f"qwenvl_{uuid.uuid4()}{original_extension}" + file_path = os.path.join(UPLOAD_DIR, filename) + with open(file_path, "wb") as f: + f.write(content) + + # 计算token + tokens_required = calculate_tokens(file_path, file_type) + + # 检查并更新token使用量 + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新token使用量 + model_name = "Qwen2-VL-2B-Instruct" + await update_token_usage(api_key, tokens_required, model_name) + + + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": filename, + "type": file_type + }).encode('utf-8')) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return JSONResponse(content={ + "message": "文件已上传并排队等待处理", + "filename": filename, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + }) + + +@qwenvl_app.post("/upload") +async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)): 
+ api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + file_type = "image" if file.content_type.startswith("image") else "video" + return await process_file(file, file_type, api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@qwenvl_app.post("/analyze_video") +async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + return await process_file(file, "video", api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +@qwenvl_app.post("/analyze_image") +async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + try: + return await process_file(file, "image", api_key) + except Exception as e: + return JSONResponse(content={"error": str(e)}, status_code=500) + +def process_task(): + for message in consumer: + try: + if isinstance(message.value, dict): + task = message.value + else: + task = json.loads(message.value.decode('utf-8')) + + filename = task['filename'] + file_type = task['type'] + + file_path = os.path.join(UPLOAD_DIR, filename) + + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "processing"})) + + with open(file_path, 'rb') as f: + file_data = f.read() + + if file_type == "video": + result = media_analysis_system.process_video(file_data, filename) + elif file_type == "image": + result = media_analysis_system.process_image(file_data, filename) + + # 保存结果到 JSON 文件 + result_file_path = os.path.join(RESULT_DIR, f"{filename}.json") + with open(result_file_path, 'w', encoding='utf-8') as f: + 
json.dump(result, f, ensure_ascii=False, indent=2) + + # 将结果存储在 Redis 中 + redis_client.set(redis_key, json.dumps({ + "status": "completed", + "result": result + })) + + except Exception as e: + print(f"Error processing task: {str(e)}") + if 'filename' in locals() and 'file_type' in locals(): + redis_key = f"{file_type}_result:{filename}" + redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)})) + else: + print("Error occurred before task details were extracted") + +@qwenvl_app.get("/result/{filename}") +async def get_result(filename: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + for file_type in ["video", "image"]: + redis_key = f"{file_type}_result:{filename}" + result = redis_client.get(redis_key) + if result: + result_json = json.loads(result) + + if result_json.get("status") == "queued": + return {"status": "queued", "message": "Your request is in the queue and will be processed soon."} + elif result_json.get("status") == "processing": + return {"status": "processing", "message": "Your request is being processed."} + else: + return result_json + + raise HTTPException(status_code=404, detail="Result not found") + + +async def listen_redis_changes(): + pubsub = redis_client.pubsub() + pubsub.psubscribe('__keyspace@5__:*_result:*') + + for message in pubsub.listen(): + if message['type'] == 'pmessage': + key = message['channel'].decode('utf-8').split(':')[-1] + print(f"Key changed: {key}") + +if __name__ == "__main__": + # 在后台线程中启动Kafka消费者 + consumer_thread = threading.Thread(target=process_task, daemon=True) + consumer_thread.start() + + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=7005) \ No newline at end of file diff --git a/api_old/before/qwenvl_key2.py b/api_old/before/qwenvl_key2.py new file mode 100644 index 0000000..06b187e --- /dev/null +++ b/api_old/before/qwenvl_key2.py @@ -0,0 +1,518 @@ 
+import os +import json +import uuid +from datetime import datetime, timedelta, timezone +from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer, KafkaConsumer +from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor +from qwen_vl_utils import process_vision_info +from decord import VideoReader, cpu +from PIL import Image +from redis import Redis +import io +import re +import torch +from contextlib import asynccontextmanager +import threading +import string + + +app = FastAPI() +qwenvl_app = FastAPI() +app.mount("/qwenvl", qwenvl_app) +torch.cuda.set_device(1) + +# CORS设置 +ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct" +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_TOPIC = "qwenvl" +KAFKA_GROUP_ID = "qwenvl_group" + +REDIS_HOST = "222.186.10.253" +REDIS_PORT = 6379 +REDIS_PASSWORD = "Obscura@2024" +REDIS_DB = 8 +REDIS_API_DB = 12 +REDIS_API_USAGE_DB = 13 + +UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload" +RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result" +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 初始化 Kafka +producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER]) +consumer = KafkaConsumer( + KAFKA_TOPIC, + bootstrap_servers=[KAFKA_BROKER], + group_id=KAFKA_GROUP_ID, + auto_offset_reset='earliest', + enable_auto_commit=True, + value_deserializer=lambda x: json.loads(x.decode('utf-8')) +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + 
+redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = 
class MediaAnalysisSystem:
    """Describe images/videos with a vision-language model and mine the answer for keywords.

    The model and processor are injected by the caller. ``process_media``
    builds a chat prompt, runs generation, decodes the answer and passes the
    text through ``extract_info``.
    """

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        # Upper bound on the number of frames sampled from one video.
        self.MAX_NUM_FRAMES = 10

    def encode_video(self, video_data):
        """Decode raw video bytes into a list of at most ``MAX_NUM_FRAMES`` PIL images."""

        def uniform_sample(l, n):
            # Pick n items centred in n equal-width slices of l.
            gap = len(l) / n
            return [l[int(i * gap + gap / 2)] for i in range(n)]

        video_file = io.BytesIO(video_data)
        vr = VideoReader(video_file)
        # Sample about one frame per second. Clamp to 1: a clip below 0.5 fps
        # would round to 0 and make range() raise "arg 3 must not be zero".
        sample_fps = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_media(self, media_data, object_name, media_type='image'):
        """Generate a Chinese description of one image or video.

        Args:
            media_data: Raw file bytes.
            object_name: Name used only in log and error messages.
            media_type: ``'video'`` for video; any other value is handled as an image.

        Returns:
            Dict with ``original_answer``, ``extracted_info`` and ``timestamp``,
            plus ``num_frames`` for videos.

        Raises:
            ValueError: If ``media_data`` is empty.
        """
        if not media_data:
            raise ValueError(f"Empty {media_type} data for {object_name}")

        print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")

        if media_type == 'video':
            frames = self.encode_video(media_data)
            media_content = {"type": "video", "video": frames, "fps": 1.0}
        else:  # image
            image = Image.open(io.BytesIO(media_data))
            media_content = {"type": "image", "image": image}

        messages = [
            {
                "role": "user",
                "content": [
                    media_content,
                    {"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to('cuda:1')
        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        answer = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        extracted_info = self.extract_info(answer)

        result = {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        if media_type == 'video':
            result["num_frames"] = len(frames)

        return result

    def process_video(self, video_data, object_name):
        """Shortcut for ``process_media(..., media_type='video')``."""
        return self.process_media(video_data, object_name, media_type='video')

    def process_image(self, image_data, object_name):
        """Shortcut for ``process_media(..., media_type='image')``."""
        return self.process_media(image_data, object_name, media_type='image')

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a ``YYYYMMDD_HHMMSS...`` file name into a 10-second time window.

        Returns:
            ``(start_time, end_time)``. When the name does not match the
            pattern, the window starts at the current time instead.
        """
        filename = os.path.basename(object_name)

        try:
            parts = filename.split('_')
            # A name without '_' raises IndexError here; it must take the same
            # fallback as an unparseable timestamp instead of escaping.
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull coarse keywords (environment, head count, actions, ...) out of the answer text.

        Returns:
            Dict with keys ``environment``, ``num_people``, ``actions``,
            ``interactions``, ``objects`` and ``furniture``.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        environments = ["办公室", "室内", "室外", "会议室"]
        for env in environments:
            if env in answer.lower():
                info["environment"] = env
                break

        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        num_word_to_digit = {
            '二': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
        }
        # Only the first pattern that matches is used.
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if match:
                head = match.group(1)
                if head.isdigit():
                    info["num_people"] = int(head)
                elif head in ['一个', '一']:
                    info["num_people"] = 1
                else:
                    # NOTE(review): patterns without a numeral (e.g. "几个",
                    # "男性") yield 0 here although people are mentioned -
                    # confirm whether consumers expect 0 or None.
                    info["num_people"] = num_word_to_digit.get(head, 0)
                break

        # "转身" was listed twice, which reported the action twice.
        actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "跳跃", "跳", "躺", "睡", "说话"]
        interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
        objects = ["水瓶", "办公用品", "文件", "电脑"]
        furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]

        for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]:
            info[key].extend(item for item in item_list if item in answer)

        return info
@qwenvl_app.post("/analyze_video")
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for video analysis.

    ``api_key_info`` is already the verified key record injected by
    ``Depends(verify_api_key)``. The previous body passed that dict back into
    ``verify_api_key`` as if it were the key string, which looked up a
    non-existent Redis key and rejected every request with 401.

    Raises:
        HTTPException: Propagated unchanged from ``process_file`` (e.g. 403
            when the token balance is insufficient).
    """
    try:
        return await process_file(file, "video", api_key_info)
    except HTTPException:
        # Keep deliberate HTTP errors instead of flattening them into 500.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@qwenvl_app.post("/analyze_image")
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for image analysis.

    ``api_key_info`` is the verified key record injected by
    ``Depends(verify_api_key)``; it must not be verified a second time.

    Raises:
        HTTPException: Propagated unchanged from ``process_file``.
    """
    try:
        return await process_file(file, "image", api_key_info)
    except HTTPException:
        # Keep deliberate HTTP errors instead of flattening them into 500.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@qwenvl_app.get("/result/{filename}")
async def get_result(filename: str, api_key: str = Depends(get_api_key)):
    """Return the stored analysis state for ``filename``.

    The video key is checked before the image key. Queued and processing
    entries are answered with a short status message; any other stored entry
    is returned as-is.

    Raises:
        HTTPException: 403 if key verification yields nothing, 404 if neither
            result key exists in Redis.
    """
    verified = await verify_api_key(api_key)
    if not verified:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    pending_messages = {
        "queued": "Your request is in the queue and will be processed soon.",
        "processing": "Your request is being processed.",
    }

    for kind in ("video", "image"):
        raw = redis_client.get(f"{kind}_result:{filename}")
        if not raw:
            continue

        payload = json.loads(raw)
        state = payload.get("status")
        if state in pending_messages:
            return {"status": state, "message": pending_messages[state]}
        return payload

    raise HTTPException(status_code=404, detail="Result not found")
app = FastAPI()
yolo_app = FastAPI()
app.mount("/yolo", yolo_app)

# CORS configuration.
# allow_origins must be a sequence of origins. A bare string is still
# accepted, but membership tests against it become substring matches, so any
# fragment of the URL would count as an allowed origin.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

# The CORS middleware is added to the main application only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def process_image(image_data, filename):
    """Run detection on encoded image bytes and save an annotated copy.

    Returns:
        ``(detections, annotated_filename)`` on success, ``(None, None)`` if
        anything fails (the error is printed).
    """
    try:
        raw_buffer = np.frombuffer(image_data, np.uint8)
        bgr_image = cv2.imdecode(raw_buffer, cv2.IMREAD_COLOR)
        source_shape = bgr_image.shape

        # BGR -> RGB, resize to 640x640, HWC -> CHW, scale by 1/255, add batch dim.
        rgb_resized = cv2.resize(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB), (640, 640))
        chw = np.ascontiguousarray(rgb_resized.transpose((2, 0, 1)))
        batch = (torch.from_numpy(chw).float() / 255.0).unsqueeze(0)

        detections = detector.format_results(detector.detect(batch), source_shape)

        # Boxes are drawn on the original-resolution image.
        annotated = detector.draw_results(bgr_image, detections)
        annotated_name = f"yolo_{filename}"
        cv2.imwrite(os.path.join(RESULT_DIR, annotated_name), annotated)

        return detections, annotated_name
    except Exception as e:
        print(f"处理图像时出错: {str(e)}")
        return None, None
def preprocess_frame(frame):
    """Turn one BGR video frame into a batched float tensor for the detector.

    Steps: BGR -> RGB, resize to 640x640, HWC -> CHW, convert to float,
    divide by 255, prepend a batch dimension.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (640, 640))
    chw = np.ascontiguousarray(resized.transpose((2, 0, 1)))
    scaled = torch.from_numpy(chw).float() / 255.0
    return scaled.unsqueeze(0)
def process_task():
    """Consume detection tasks from Kafka forever and publish each outcome to Redis.

    Every message is expected to carry ``filename`` and ``file_type``. The
    state of a task is stored under ``yolo_result:<filename>`` as
    ``processing``, then ``completed`` (with results) or ``failed``.
    """
    for message in consumer:
        task = message.value
        try:
            filename = task['filename']
            file_type = task['file_type']
        except (KeyError, TypeError) as e:
            # A malformed message must not terminate the consumer thread.
            print(f"Error processing task: {str(e)}")
            continue

        file_path = os.path.join(UPLOAD_DIR, filename)

        # Update status to "processing"
        redis_key = f"yolo_result:{filename}"
        redis_client.set(redis_key, json.dumps({"status": "processing"}))

        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if file_type == "image":
                json_results, annotated_filename = process_image(content, filename)
                # process_image signals failure with None; an empty list is a
                # valid "nothing detected" result and must not be reported as failed.
                succeeded = json_results is not None
            else:
                json_results, annotated_filename = process_video(content, filename)
                # Each sampled frame adds one entry, so an empty list means
                # no frame could be read at all.
                succeeded = bool(json_results)

            if succeeded and annotated_filename:
                redis_client.set(redis_key, json.dumps({
                    "json_results": json_results,
                    "status": "completed",
                    "annotated_filename": annotated_filename
                }))
            else:
                redis_client.set(redis_key, json.dumps({"status": "failed"}))
        except Exception as e:
            print(f"Error processing task: {str(e)}")
            redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
    """Return the raw API key taken from the ``API_KEY_NAME`` request header.

    Raises:
        HTTPException: 400 when the header is absent.
    """
    if api_key is not None:
        return api_key
    raise HTTPException(status_code=400, detail="API密钥缺失")
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Look up an API key in Redis and return its metadata merged with usage counters.

    Returns:
        A dict combining the key's hash from ``redis_api_client`` with the
        matching hash from ``redis_api_usage_client``, or ``None`` when the
        key is unknown, inactive, expired, or has a missing/unparseable
        ``expires_at``.

    NOTE(review): failure is signalled by ``None``, not by an exception, so a
    route that only lists this under ``dependencies=[Depends(...)]`` does not
    reject invalid keys - callers must check the return value.
    """
    redis_key = f"api_key:{api_key}"

    api_key_info = redis_api_client.hgetall(redis_key)

    if not api_key_info:
        return None

    # Redis returns bytes for both field names and values.
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}

    if api_key_info.get('is_active') != '1':
        return None

    # A missing or malformed expiry used to raise TypeError/ValueError and
    # turn into HTTP 500; treat it like any other invalid key.
    try:
        expires_at = datetime.fromisoformat(api_key_info.get('expires_at', ''))
    except (TypeError, ValueError):
        return None

    # Naive and aware datetimes cannot be compared.
    # NOTE(review): assumes naive stored timestamps are UTC - confirm with the writer.
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)

    if datetime.now(timezone.utc) > expires_at:
        return None

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    return {
        **api_key_info,
        **usage_info
    }
class yoloDetector:
    """Thin wrapper around a YOLO model: inference, JSON formatting and box drawing."""

    def __init__(self, model_path):
        self.model = YOLO(model_path)

    def detect(self, frame):
        """Run the model on ``frame`` and return its raw results."""
        return self.model(frame)

    def format_results(self, results, original_shape):
        """Convert raw results into a list of ``{"class", "confidence", "bbox"}`` dicts.

        Box coordinates are rescaled from a 640x640 input to
        ``original_shape`` (height first, width second).
        """
        target_h, target_w = original_shape[0], original_shape[1]
        detections = []
        for result in results:
            for box in result.boxes:
                left, top, right, bottom = box.xyxy[0].tolist()
                class_id = int(box.cls.item())
                detections.append({
                    "class": self.model.names[class_id],
                    "confidence": box.conf.item(),
                    "bbox": [
                        left * target_w / 640,
                        top * target_h / 640,
                        right * target_w / 640,
                        bottom * target_h / 640,
                    ],
                })
        return detections

    def draw_results(self, frame, formatted_results):
        """Draw every detection (green box plus label) onto ``frame`` and return it."""
        green = (0, 255, 0)  # fixed box colour
        font = cv2.FONT_HERSHEY_SIMPLEX
        for detection in formatted_results:
            x1, y1, x2, y2 = (int(v) for v in detection['bbox'])
            label = f"{detection['class']} {detection['confidence']:.2f}"

            cv2.rectangle(frame, (x1, y1), (x2, y2), green, 2)
            (text_width, text_height), _ = cv2.getTextSize(label, font, 1, 2)
            # Filled strip behind the label, then the label text in black.
            cv2.rectangle(frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), green, -1)
            cv2.putText(frame, label, (x1, y1 - 5), font, 1, (0, 0, 0), 2)

        return frame
def process_video(video_data, filename):
    """Run detection on roughly one frame per second of a video and write an annotated copy.

    The bytes are written to a temporary file under ``UPLOAD_DIR`` so OpenCV
    can open them. Only the sampled frames are annotated and written to the
    output, which is encoded at 1 fps.

    Returns:
        ``(per_frame_results, annotated_filename)`` on success, where each
        entry is ``{"frame": index, "detections": [...]}``; ``(None, None)``
        if anything fails (the error is printed).
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"yolo_{filename}")
    cap = None
    out = None
    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        frame_count = 0
        json_results = []

        # The reported rate truncates to 0 for missing or sub-1-fps metadata;
        # fall back to 1 so the modulo below cannot divide by zero.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 1
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        original_shape = (height, width)

        annotated_filename = f"yolo_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Process only one frame per second of source video.
            if frame_count % fps == 0:
                preprocessed_frame = preprocess_frame(frame)

                results = detector.detect(preprocessed_frame)
                frame_json_results = detector.format_results(results, original_shape)
                json_results.append({"frame": frame_count, "detections": frame_json_results})

                annotated_frame = detector.draw_results(frame, frame_json_results)
                out.write(annotated_frame)

            frame_count += 1

        return json_results, annotated_filename
    except Exception as e:
        print(f"处理视频时出错: {str(e)}")
        return None, None
    finally:
        # Release handles and remove the temp copy on the error path too;
        # previously they leaked whenever an exception interrupted the loop.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
frame_resized.transpose((2, 0, 1)) # HWC转为CHW + frame_contiguous = np.ascontiguousarray(frame_transposed) + frame_tensor = torch.from_numpy(frame_contiguous).float() + frame_normalized = frame_tensor / 255.0 # 归一化到[0, 1] + frame_batched = frame_normalized.unsqueeze(0) # 添加批次维度 + return frame_batched + +@yolo_app.post("/upload") +async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)): + # 验证 API key + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + # 保存原始文件 + original_file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(original_file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(original_file_path, file_type) + if tokens_required is None or tokens_required <= 0: + raise HTTPException(status_code=500, detail="无法计算所需的token数量") + + # 检查并更新 token 使用量 + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + model_name = "yolov8x" + await update_token_usage(api_key, tokens_required, model_name) + + + # 发送处理任务到 Kafka + producer.send(KAFKA_TOPIC, json.dumps({ + "filename": new_filename, + "file_type": file_type + }).encode('utf-8')) + + # 在 Redis 中设置初始状态 + redis_key = f"yolo_result:{new_filename}" + redis_client.set(redis_key, json.dumps({"status": "queued"})) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + 
@yolo_app.get("/result/{filename}")
async def get_yolo_result(filename: str, api_key_info: dict = Depends(verify_api_key)):
    """Return the stored detection state for ``filename``.

    ``verify_api_key`` reports an invalid key by returning ``None``. Listing
    it only under ``dependencies=[...]`` discarded that value, so invalid keys
    were accepted; the result is now injected and checked.

    Raises:
        HTTPException: 403 for an invalid key, 404 when no result is stored.
    """
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    redis_key = f"yolo_result:{filename}"
    result = redis_client.get(redis_key)
    if result:
        result_data = json.loads(result)
        return JSONResponse(content=result_data)  # whole record, including status
    else:
        raise HTTPException(status_code=404, detail="Result not found")


@yolo_app.get("/annotated/{filename}")
async def get_annotated_file(filename: str, api_key_info: dict = Depends(verify_api_key)):
    """Stream the annotated image or video produced for ``filename``.

    Raises:
        HTTPException: 403 for an invalid key; 404 when the task is not
            completed or the annotated file is missing.
    """
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    # ".jpg" must be served as image/jpeg; "image/jpg" is not a registered type.
    image_media_types = {'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png'}

    redis_key = f"yolo_result:{filename}"
    result = redis_client.get(redis_key)
    if result:
        result_data = json.loads(result)
        annotated_filename = result_data.get("annotated_filename")
        if result_data.get("status") == "completed" and annotated_filename:
            file_path = os.path.join(RESULT_DIR, annotated_filename)
            if os.path.exists(file_path):
                def iterfile():
                    with open(file_path, mode="rb") as file_like:
                        yield from file_like
                file_extension = os.path.splitext(annotated_filename)[1].lower()
                return StreamingResponse(
                    iterfile(),
                    media_type=image_media_types.get(file_extension, "video/mp4"),
                )

    raise HTTPException(status_code=404, detail="Annotated file not found")
def listen_redis_changes():
    """Print a status line whenever a ``yolo_result:*`` key is set in Redis.

    Blocks forever on the pub/sub connection, so it is meant to run in its
    own thread.
    NOTE(review): the server must have keyspace notifications enabled for any
    message to arrive - confirm the Redis configuration.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are per database. The pattern was hard-coded to db 3
    # while redis_client writes to REDIS_DB, so no event could ever match.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:yolo_result:*')

    for message in pubsub.listen():
        if message['type'] == 'pmessage':
            # The channel ends with ":<filename>".
            key = message['channel'].decode('utf-8').split(':')[-1]
            operation = message['data'].decode('utf-8')

            if operation == 'set':
                value = redis_client.get(f"yolo_result:{key}")
                if value:
                    result = json.loads(value)
                    print(f"Status update for {key}: {result.get('status')}")

                    # Further handling (e.g. notifications) can be added here.
# CORS configuration.
# ``allow_origins`` expects a sequence of origin strings. Passing a bare
# string makes any membership test against it a substring match, so the
# single allowed origin is wrapped in a list.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

# The CORS middleware is added to the main application only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Estimate how many tokens processing a file costs.

    The cost is the sum of:

    * a size component: 10 tokens per MiB of file size;
    * for ``"image"``: 5 tokens per 10,000 pixels;
    * for ``"video"``: (pixels summed over all frames / 10,000) multiplied
      by the duration in minutes.

    Any other ``file_type`` is charged the size component only.

    Args:
        file_path: Path of the uploaded file on disk.
        file_type: ``"image"`` or ``"video"``.

    Returns:
        The token cost, never less than 1. When the file cannot be inspected
        (missing file, unreadable image, video that cannot be opened) the
        default cost of 1 is returned.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes

        # Size component: 10 tokens per MiB.
        base_tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            img = cv2.imread(file_path)
            if img is None:
                raise ValueError("无法读取图片文件")
            height, width = img.shape[:2]
            pixel_count = height * width

            # Image component: 5 tokens per 10,000 pixels.
            base_tokens += int((pixel_count / 10000) * 5)

        elif file_type == "video":
            cap = cv2.VideoCapture(file_path)
            try:
                if not cap.isOpened():
                    raise ValueError("无法打开视频文件")
                fps = cap.get(cv2.CAP_PROP_FPS)
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # Release the capture on the error path as well.
                cap.release()

            # A reported FPS of 0 used to raise ZeroDivisionError, which the
            # handler below turned into a flat 1-token charge. When the
            # duration is unknown, charge the size component instead.
            if fps > 0 and frame_count > 0:
                pixel_count = width * height * frame_count
                duration = frame_count / fps  # seconds

                # Video component: pixels / 10,000, times minutes of video.
                base_tokens += int((pixel_count / 10000) * (duration / 60))

        return max(1, base_tokens)  # always charge at least one token
    except Exception as e:
        print(f"计算token时出错: {str(e)}")
        return 1  # default cost when the file cannot be inspected
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Record token consumption for an API key, in total and per model.

    Queues four updates on the ``api_key:<api_key>`` hash of the usage
    database and submits them with a single ``execute()`` call: the total and
    the per-model counters are incremented by ``new_tokens_used`` and the two
    "last used" fields are set to the current UTC time in ISO 8601 form.

    Args:
        api_key: The API key whose usage hash is updated.
        new_tokens_used: Number of tokens to add to both counters.
        model_name: Prefix of the per-model fields.
    """
    usage_key = f"api_key:{api_key}"
    timestamp = datetime.now(timezone.utc).isoformat()

    # (counter field, timestamp field): overall usage first, then per model.
    field_pairs = (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    )

    pipeline = redis_api_usage_client.pipeline()
    for counter_field, stamp_field in field_pairs:
        pipeline.hincrby(usage_key, counter_field, new_tokens_used)
        pipeline.hset(usage_key, stamp_field, timestamp)
    pipeline.execute()
def process_image(image_data, filename):
    """Run detection on an encoded image and save an annotated copy.

    Args:
        image_data: Raw bytes of the encoded image file.
        filename: Name of the uploaded file; the annotated image is written
            to ``RESULT_DIR`` as ``yolo_<filename>``.

    Returns:
        ``(json_results, annotated_filename)`` on success, or ``(None, None)``
        when decoding, detection or saving fails (the error is printed).
    """
    try:
        original_img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if original_img is None:
            # imdecode signals undecodable data by returning None.
            raise ValueError("image data could not be decoded")
        original_shape = original_img.shape

        # Same 640x640 RGB tensor preprocessing as used for video frames.
        img = preprocess_frame(original_img)

        results = detector.detect(img)
        json_results = detector.format_results(results, original_shape)
        annotated_img = detector.draw_results(original_img, json_results)

        annotated_filename = f"yolo_{filename}"
        annotated_path = os.path.join(RESULT_DIR, annotated_filename)
        # imwrite reports failure through its return value; do not report
        # success for a file that was never written.
        if not cv2.imwrite(annotated_path, annotated_img):
            raise IOError(f"could not write annotated image to {annotated_path}")

        return json_results, annotated_filename
    except Exception as e:
        print(f"处理图像时出错: {str(e)}")
        return None, None
def preprocess_frame(frame):
    """Turn one BGR frame into a batched float tensor for the detector.

    Steps: BGR to RGB, resize to 640x640, reorder HWC to CHW, convert to a
    float tensor, divide by 255, and add a leading batch dimension.
    """
    rgb_resized = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (640, 640))
    # CHW layout in contiguous memory, as required by torch.from_numpy.
    chw = np.ascontiguousarray(rgb_resized.transpose((2, 0, 1)))
    scaled = torch.from_numpy(chw).float() / 255.0
    return scaled.unsqueeze(0)
@yolo_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
async def get_annotated_file(filename: str):
    """Stream the annotated image or video produced for ``filename``.

    The result stored under ``yolo_result:<filename>`` must have status
    ``"completed"`` and name an annotated file that exists in ``RESULT_DIR``.

    Raises:
        HTTPException: 404 when there is no completed result or the
            annotated file is missing.
    """
    # Extension -> registered media type. ".jpg" must map to "image/jpeg";
    # "image/jpg" is not a registered media type.
    image_media_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
    }

    result = redis_client.get(f"yolo_result:{filename}")
    if result:
        result_data = json.loads(result)
        # .get() keeps a malformed record a 404 instead of a KeyError / 500.
        annotated_filename = result_data.get("annotated_filename")
        if result_data.get("status") == "completed" and annotated_filename:
            file_path = os.path.join(RESULT_DIR, annotated_filename)
            if os.path.exists(file_path):
                def iterfile():
                    with open(file_path, mode="rb") as file_like:
                        yield from file_like

                file_extension = os.path.splitext(annotated_filename)[1].lower()
                # Anything that is not a known image type is served as MP4.
                media_type = image_media_types.get(file_extension, "video/mp4")
                return StreamingResponse(iterfile(), media_type=media_type)

    raise HTTPException(status_code=404, detail="Annotated file not found")
def listen_redis_changes():
    """Print status updates for ``yolo_result:*`` keys.

    Pattern-subscribes to the keyspace channel of the database that
    ``redis_client`` writes to (``REDIS_DB``) and, for each ``set`` event,
    re-reads the key and prints its ``status`` field. The loop iterates
    ``pubsub.listen()`` and has no exit condition of its own.

    NOTE(review): keyspace events are only published when the server has
    ``notify-keyspace-events`` enabled — confirm the server configuration.
    """
    key_prefix = "yolo_result:"
    pubsub = redis_client.pubsub()
    # The channel embeds the database index. The pattern used to be
    # hard-coded to database 3 while results are stored in REDIS_DB, so no
    # notification could ever match.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:{key_prefix}*')

    for message in pubsub.listen():
        if message['type'] != 'pmessage':
            continue

        channel = message['channel'].decode('utf-8')
        operation = message['data'].decode('utf-8')
        if operation != 'set':
            continue

        # Channel is "__keyspace@<db>__:yolo_result:<name>". Take everything
        # after the prefix so names containing ':' survive intact.
        key = channel.partition(f":{key_prefix}")[2]
        value = redis_client.get(f"{key_prefix}{key}")
        if value:
            result = json.loads(value)
            print(f"Status update for {key}: {result.get('status')}")
            # Additional handling (for example sending notifications) can
            # go here.
def get_audio_filename(text):
    """Return the file name for ``text``: its MD5 hex digest plus ``.wav``."""
    digest = hashlib.md5(text.encode())
    return f"{digest.hexdigest()}.wav"
# The rule must declare the <filename> variable: without it Flask calls the
# view with no arguments and the required ``filename`` parameter is missing.
@app.route('/audio_files/<filename>', methods=['GET', 'OPTIONS'])
def get_audio(filename):
    """Serve a WAV file from ``AUDIO_DIR``.

    Args:
        filename: Name of the file inside ``AUDIO_DIR``, taken from the URL.

    Returns:
        An empty 204 response for ``OPTIONS``; otherwise the file with MIME
        type ``audio/wav``, or a JSON error with status 404 when it cannot
        be sent.
    """
    if request.method == 'OPTIONS':
        return '', 204
    try:
        return send_file(os.path.join(AUDIO_DIR, filename), mimetype='audio/wav')
    except Exception as e:
        return jsonify({"error": str(e)}), 404
@app.route('/synthesize', methods=['POST'])
def synthesize():
    """Synthesize speech for the posted text and return it as WAV audio.

    Expects a JSON object with a non-empty ``text`` field.

    Returns:
        The audio with MIME type ``audio/wav``; a JSON error with status 400
        when no text is supplied; a JSON error with status 500 when
        synthesis fails.
    """
    import io  # local import: only this handler needs it

    # A missing, malformed or non-object JSON body is treated as "no text"
    # instead of raising outside the error handling below.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text:
        return jsonify({"error": "No text provided"}), 400

    try:
        params_infer_code = ChatTTS.Chat.InferCodeParams(
            spk_emb=FIXED_SPEAKER,
            temperature=0.3,
            top_P=0.7,
            top_K=20,
        )

        params_refine_text = ChatTTS.Chat.RefineTextParams(
            prompt='[oral_2][laugh_0][break_6]',
        )

        wavs = chat_tts.infer(text, params_refine_text=params_refine_text, params_infer_code=params_infer_code)

        audio_data = wavs[0]

        if not np.issubdtype(audio_data.dtype, np.floating):
            audio_data = audio_data.astype(np.float32)

        # Scale down only when the signal exceeds the [-1, 1] range.
        peak = np.max(np.abs(audio_data))
        if peak > 1:
            audio_data = audio_data / peak

        # Encode in memory. The previous temp-file approach never closed the
        # NamedTemporaryFile handle and unlinked the file before the response
        # had been sent.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, SAMPLE_RATE, format='WAV')
        audio_buffer.seek(0)

        return send_file(audio_buffer, mimetype='audio/wav')

    except Exception as e:
        return jsonify({"error": str(e)}), 500
def main():
    """Synthesize the hard-coded demo sentence and save it as ``output.wav``.

    All model, reference and output paths are fixed; they are collected in
    one mapping and passed to ``synthesize`` by keyword.
    """
    settings = {
        "GPT_model_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
        "SoVITS_model_path": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
        "ref_audio_path": "/home/zydi//worker_chat/kafka/sample/woman.wav",
        "ref_text_path": "/home/zydi//worker_chat/kafka/sample/woman.txt",
        "ref_language": "中文",
        "target_text": """我们开发了"病人实时健康监测系统"和"AI辅助诊断系统",这些系统显著提高了医疗诊断的效率和准确性。obscura形成了全面的医疗智能解决方案""",
        "target_language": "多语种混合",
        "output_path": "/home/zydi//worker_chat/kafka",
        "output_filename": "output.wav",
    }

    synthesize(**settings)
"你是一个智能助手,请用尽可能简短、简洁的方式回答问题。当用户询问数据库相关内容时,你可以访问MongoDB数据库来获取额外的信息。" + +def search_mongodb(mongodb_client, query, mongo_db, search_collection): + collection = mongodb_client[mongo_db][search_collection] + results = collection.find({"$text": {"$search": query}}) + + processed_results = [] + for result in results: + if 'original_answer' in result: + processed_results.append(result['original_answer']) + else: + # 记录这个问题并跳过这个结果 + print(f"警告: 文档缺少 'original_answer' 字段: {result['_id']}") + + return processed_results + +def get_conversation_history(redis_client, conversation_id, max_history=5): + history = redis_client.lrange(f"conversation:{conversation_id}", 0, max_history * 2 - 1) + return list(zip(history[::2], history[1::2])) + +def add_to_conversation_history(redis_client, conversation_id, query, answer): + redis_client.rpush(f"conversation:{conversation_id}", query, answer) + redis_client.expire(f"conversation:{conversation_id}", 3600) # 设置1小时的过期时间 + +def generate_answer(query, context, history, mongo_client, mongo_db, search_collection): + full_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + for past_query, past_response in history: + full_prompt += f"用户: {past_query}\n助手: {past_response}\n" + full_prompt += f"用户: {query}\n上下文: {context}\n\n" + full_prompt += "请根据上下文和历史对话回答用户的问题。如果需要额外信息,你可以使用以下函数查询MongoDB数据库:\n" + full_prompt += "search_mongodb(query: str) -> List[str]\n" + full_prompt += "该函数会返回与查询相关的内容列表。\n" + full_prompt += "助手: " + + def search_mongodb_wrapper(query): + return search_mongodb(mongo_client, query, mongo_db, search_collection) + + data = { + "model": "llama3.1", + "prompt": full_prompt, + "stream": True, + "temperature": 0, + "functions": [ + { + "name": "search_mongodb", + "description": "Search for information in MongoDB", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query" + } + }, + "required": ["query"] + } + } + ], + "function_call": "auto" + } + + try: + response = 
def process_message(message, kafka_config, mongodb_client, redis_client, mongo_db, search_collection):
    """Answer one incoming text message and publish the answer.

    The payload may be a dict or a JSON string with ``text`` and an optional
    ``conversation_id``; anything else is used verbatim as the query with the
    conversation id ``'default'``. The function loads the conversation
    history, searches MongoDB for context, generates an answer, appends the
    exchange to the history and sends the answer to the voice-output topic.

    Args:
        message: Kafka record; only ``message.value`` is read.
        kafka_config: Mapping providing ``bootstrap_servers`` and
            ``voice_output_topic``.
        mongodb_client: Client used for the context search.
        redis_client: Client used for the conversation history.
        mongo_db: Name of the MongoDB database to search.
        search_collection: Name of the collection to search.
    """
    try:
        # The consumer may already have deserialized the value into a dict.
        if isinstance(message.value, dict):
            message_data = message.value
        else:
            # Otherwise try to parse it as JSON.
            message_data = json.loads(message.value)

        query = message_data.get('text', '')
        conversation_id = message_data.get('conversation_id', 'default')
    except (json.JSONDecodeError, AttributeError, TypeError):
        # Not JSON, not a mapping, or not a str/bytes value at all: treat
        # the whole payload as the query text.
        query = str(message.value)
        conversation_id = 'default'

    # Load the conversation history.
    history = get_conversation_history(redis_client, conversation_id)

    # Search MongoDB for related information.
    search_results = search_mongodb(mongodb_client, query, mongo_db, search_collection)
    context = " ".join(search_results)

    # Generate the answer with the language model.
    answer = generate_answer(query, context, history, mongodb_client, mongo_db, search_collection)

    # Append the exchange to the Redis history.
    add_to_conversation_history(redis_client, conversation_id, query, answer)

    # Publish the answer to the voice-output topic. send() only queues the
    # record, so flush before dropping the producer, and close it so that
    # per-message producers do not accumulate connections.
    producer = KafkaProducer(bootstrap_servers=[kafka_config['bootstrap_servers']],
                             value_serializer=lambda x: json.dumps(x).encode('utf-8'))
    try:
        producer.send(kafka_config['voice_output_topic'], {'answer': answer, 'conversation_id': conversation_id})
        producer.flush()
    finally:
        producer.close()

    print(f"Processed message: {query}")
    print(f"Generated answer: {answer}")
    print(f"Sent to voice-output topic: {{'answer': '{answer}', 'conversation_id': '{conversation_id}'}}")
def get_audio_from_minio(minio_client, bucket, object_name):
    """Download an object from MinIO and return its content.

    Args:
        minio_client: Client exposing ``get_object(bucket, object_name)``.
        bucket: Name of the bucket holding the object.
        object_name: Key of the object inside the bucket.

    Returns:
        The object's bytes, or ``None`` when the download fails (the error
        and the requested location are printed).
    """
    response = None
    try:
        response = minio_client.get_object(bucket, object_name)
        return response.read()
    except Exception as e:
        print(f"从 MinIO 获取音频时出错: {str(e)}")
        print(f"Bucket: {bucket}, Object: {object_name}")
        return None
    finally:
        # The response holds a pooled HTTP connection; close it and hand the
        # connection back whether or not the read succeeded.
        if response is not None:
            response.close()
            response.release_conn()
format="wav") + else: + wav_path = temp_audio_file.name + + result = model.transcribe(wav_path) + + os.unlink(temp_audio_file.name) + if file_extension.lower() != 'wav': + os.unlink(wav_path) + + return json.dumps({"text": result["text"]}) + +def consumer_thread(kafka_config, minio_config, model, partition): + client_id = f'voice_{partition}' + group_id = f"{kafka_config['consumer_group']}_{partition}" + + consumer = KafkaConsumer( + bootstrap_servers=[kafka_config['bootstrap_servers']], + group_id=group_id, + client_id=client_id, + value_deserializer=lambda x: json.loads(x.decode('utf-8')), + enable_auto_commit=True, + auto_commit_interval_ms=5000 + ) + + # 手动分配分区 + topic_partition = TopicPartition(kafka_config['voice_input_topic'], partition) + consumer.assign([topic_partition]) + + producer = KafkaProducer( + bootstrap_servers=[kafka_config['bootstrap_servers']], + value_serializer=lambda x: json.dumps(x).encode('utf-8') + ) + + minio_client = Minio( + minio_config['endpoint'], + access_key=minio_config['access_key'], + secret_key=minio_config['secret_key'], + secure=minio_config['secure'] + ) + + print(f"消费者 {client_id} 开始监听 {kafka_config['voice_input_topic']} 主题的分区 {partition}...") + + + for message in consumer: + print(f"消费者 {client_id} 从分区 {partition} 收到新的音频消息") + event_info = message.value + print(f"事件信息: {event_info}") + + # 从S3事件中提取音频文件路径 + audio_path = event_info.get('Key') + if not audio_path: + # 如果在顶层没有找到Key,尝试从Records中获取 + records = event_info.get('Records', []) + if records: + audio_path = records[0].get('s3', {}).get('object', {}).get('key') + + if not audio_path: + print(f"消费者 {client_id} 无法从事件中获取音频路径") + continue + + # 移除可能的 'audio/' 前缀 + if audio_path.startswith('audio/'): + audio_path = audio_path[6:] + + file_extension = audio_path.split('.')[-1] if '.' 
def main(kafka_config, minio_config, whisper_config):
    """Load the Whisper model and run one consumer thread per partition.

    A single model instance is shared by all threads. The call returns only
    after every consumer thread has finished.

    Args:
        kafka_config: Mapping providing ``num_partitions`` (and the settings
            read by ``consumer_thread``).
        minio_config: MinIO settings passed through to ``consumer_thread``.
        whisper_config: Mapping providing ``model_name``.
    """
    model = whisper.load_model(whisper_config['model_name'])

    workers = []
    for partition_index in range(kafka_config['num_partitions']):
        worker = threading.Thread(
            target=consumer_thread,
            args=(kafka_config, minio_config, model, partition_index),
        )
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()
minio_config, whisper_config) \ No newline at end of file diff --git a/chat_history/kafka_chat/voice_output.py b/chat_history/kafka_chat/voice_output.py new file mode 100644 index 0000000..231fb73 --- /dev/null +++ b/chat_history/kafka_chat/voice_output.py @@ -0,0 +1,202 @@ +import json +import hashlib +import io +import threading +from concurrent.futures import ThreadPoolExecutor +from kafka import KafkaConsumer, KafkaProducer +from minio import Minio +from GPT_SoVITS.inference_webui import get_tts_wav +from tools.i18n.i18n import I18nAuto +import soundfile as sf +import redis + +i18n = I18nAuto() + +# Global variables +global_model_config = None + +def initialize_models(model_config): + global global_model_config + global_model_config = model_config + print("Models initialized") + +def generate_content_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +def synthesize(target_text): + global global_model_config + + with open(global_model_config['ref_text_path'], 'r', encoding='utf-8') as file: + ref_text = file.read() + + synthesis_result = get_tts_wav( + ref_wav_path=global_model_config['ref_audio_path'], + prompt_text=ref_text, + prompt_language=i18n(global_model_config['ref_language']), + text=target_text, + text_language=i18n(global_model_config['target_language']), + top_p=1, temperature=1 + ) + + result_list = list(synthesis_result) + print(f"Synthesizing audio for text: {target_text}") + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + return last_sampling_rate, last_audio_data + return None, None + +def process_message(message_data, minio_client, minio_config, redis_client): + target_text = message_data.get('answer', message_data.get('text', '')) + content_hash = generate_content_hash(target_text) + + # Check Redis cache + cached_audio_info = redis_client.get(content_hash) + if cached_audio_info: + print(f"Using Redis cached audio, content hash: {content_hash}") + return json.loads(cached_audio_info) + + # If not in 
Redis, check MinIO + bucket_name = minio_config['bucket'] + object_name = f"{content_hash}.wav" + try: + minio_client.stat_object(bucket_name, object_name) + print(f"Using existing audio from MinIO, content hash: {content_hash}") + audio_info = { + 'id': content_hash, + 'text': target_text, + 'content_hash': content_hash, + 'minio_bucket': bucket_name, + 'minio_object': object_name, + 'status': 'completed', + } + redis_client.set(content_hash, json.dumps(audio_info)) + return audio_info + except: + pass # Object doesn't exist in MinIO, continue to synthesis + + # If not found, synthesize new audio + try: + sampling_rate, audio_data = synthesize(target_text) + + if audio_data is not None: + audio_id = content_hash + object_name = f"{audio_id}.wav" + audio_buffer = io.BytesIO() + sf.write(audio_buffer, audio_data, sampling_rate, format='wav') + audio_buffer.seek(0) + + minio_client.put_object( + bucket_name, object_name, audio_buffer, + length=audio_buffer.getbuffer().nbytes + ) + + etag = minio_client.stat_object(bucket_name, object_name).etag + + audio_info = { + 'id': audio_id, + 'text': target_text, + 'sampling_rate': sampling_rate, + 'content_hash': content_hash, + 'minio_bucket': bucket_name, + 'minio_object': object_name, + 'etag': etag, + 'status': 'completed', + } + + redis_client.set(content_hash, json.dumps(audio_info)) + + return audio_info + except Exception as e: + print(f"Error processing message: {e}") + error_info = { + 'status': 'failed', + 'error': str(e), + 'text': target_text, + 'content_hash': content_hash, + } + redis_client.set(content_hash, json.dumps(error_info)) + return None + +def message_handler(message, minio_client, minio_config, redis_client): + print(f"Processing message: {message.value}") + message_data = json.loads(message.value) + audio_info = process_message(message_data, minio_client, minio_config, redis_client) + +def consumer_thread(consumer_id, kafka_config, minio_config, redis_config): + consumer = KafkaConsumer( + 
kafka_config['text_input_topic'], + bootstrap_servers=kafka_config['bootstrap_servers'], + auto_offset_reset='latest', + enable_auto_commit=True, + group_id=kafka_config['consumer_group'] + ) + + minio_client = Minio( + minio_config['endpoint'], + access_key=minio_config['access_key'], + secret_key=minio_config['secret_key'], + secure=minio_config['secure'] + ) + + redis_client = redis.Redis( + host=redis_config['host'], + port=redis_config['port'], + db=redis_config['db'], + password=redis_config['password'] + ) + + with ThreadPoolExecutor(max_workers=kafka_config['threads_per_consumer']) as executor: + print(f"Consumer {consumer_id} started running") + for message in consumer: + executor.submit(message_handler, message, minio_client, minio_config, redis_client) + +def main(kafka_config, minio_config, model_config, redis_config): + initialize_models(model_config) + + threads = [] + for i in range(kafka_config['num_consumers']): + t = threading.Thread(target=consumer_thread, args=(i, kafka_config, minio_config, redis_config)) + threads.append(t) + t.start() + + for t in threads: + t.join() + +if __name__ == "__main__": + # Kafka configuration + kafka_config = { + 'bootstrap_servers': '222.186.136.78:9092', + 'text_input_topic': 'voice-output', + 'consumer_group': 'voice_group', + 'num_consumers': 3, + 'threads_per_consumer': 4 + } + + # MinIO configuration + minio_config = { + 'endpoint': "api.obscura.work", + 'access_key': "00v3MtLtIAIkR3hkIuYR", + 'secret_key': "XfDeVe5bJjPU21NEYc023gzJVUTJzQqxsWHqIKMf", + 'bucket': 'tts-audio', + 'secure': True + } + + # Redis configuration + redis_config = { + 'host': '222.186.136.78', + 'port': 6379, + 'db': 4, + 'password': "Obscura@2024" + } + + # Model configuration + model_config = { + 'GPT_model_path': "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + 'SoVITS_model_path': "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + 'ref_audio_path': 
"sample/woman.wav", + 'ref_text_path': "sample/woman.txt", + 'ref_language': "中文", + 'target_language': "多语种混合" + } + + main(kafka_config, minio_config, model_config, redis_config) \ No newline at end of file diff --git a/chat_history/kafka_chat/weight.json b/chat_history/kafka_chat/weight.json new file mode 100644 index 0000000..58ec624 --- /dev/null +++ b/chat_history/kafka_chat/weight.json @@ -0,0 +1 @@ +{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}} \ No newline at end of file diff --git a/chat_history/qwen_service.py b/chat_history/qwen_service.py new file mode 100644 index 0000000..5e9ce48 --- /dev/null +++ b/chat_history/qwen_service.py @@ -0,0 +1,96 @@ +# from flask import Flask, request, jsonify +# import requests +# import json +# from flask_cors import CORS + +# app = Flask(__name__) +# CORS(app) +# DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。" + +# @app.route('/chat', methods=['POST']) +# def chat(): +# data = request.json +# query = data.get('query') +# history = data.get('history', []) +# model = data.get('model', 'qwen2') + +# full_prompt = f"{DEFAULT_SYSTEM_PROMPT}\n用户: {query}" + +# data = { +# "model": model, +# "prompt": full_prompt, +# "stream": True, +# "temperature": 0 +# } + +# try: +# response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True) +# response.raise_for_status() + +# text_output = "" +# for line in response.iter_lines(): +# if line: +# json_data = json.loads(line) +# if 'response' in json_data: +# text_output += json_data['response'] + +# final_history = history + [(query, text_output)] + +# return jsonify({"response": text_output, "history": final_history}) + +# except Exception as e: +# return jsonify({"error": str(e)}), 500 + +# if __name__ == '__main__': +# app.run(port=5001) + + +from flask import Flask, request, jsonify +import 
requests +import json +from flask_cors import CORS + +app = Flask(__name__) +CORS(app) +DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思" + +@app.route('/chat', methods=['POST']) +def chat(): + data = request.json + query = data.get('query') + history = data.get('history', []) + model = data.get('model', 'qwen2') + + # 构建包含历史对话的完整提示 + full_prompt = DEFAULT_SYSTEM_PROMPT + "\n" + for past_query, past_response in history: + full_prompt += f"用户: {past_query}\n助手: {past_response}\n" + full_prompt += f"用户: {query}" + + data = { + "model": model, + "prompt": full_prompt, + "stream": True, + "temperature": 0 + } + + try: + response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True) + response.raise_for_status() + + text_output = "" + for line in response.iter_lines(): + if line: + json_data = json.loads(line) + if 'response' in json_data: + text_output += json_data['response'] + + final_history = history + [(query, text_output)] + + return jsonify({"response": text_output, "history": final_history}) + + except Exception as e: + return jsonify({"error": str(e)}), 500 + +if __name__ == '__main__': + app.run(port=5001) \ No newline at end of file diff --git a/chat_history/sovit.py b/chat_history/sovit.py new file mode 100644 index 0000000..7e930db --- /dev/null +++ b/chat_history/sovit.py @@ -0,0 +1,49 @@ +import argparse +import os +import soundfile as sf + +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav + +i18n = I18nAuto() + +def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text, target_language, output_path, output_filename): + # Read reference text + with open(ref_text_path, 'r', encoding='utf-8') as file: + ref_text = file.read() + + # Change model weights + change_gpt_weights(gpt_path=GPT_model_path) + change_sovits_weights(sovits_path=SoVITS_model_path) + + 
# Synthesize audio + synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path, + prompt_text=ref_text, + prompt_language=i18n(ref_language), + text=target_text, + text_language=i18n(target_language), top_p=1, temperature=1) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + output_wav_path = os.path.join(output_path, output_filename) + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + print(f"Audio saved to {output_wav_path}") + +def main(): + GPT_model_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + SoVITS_model_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" + ref_audio_path = "/home/zydi//worker_chat/kafka/sample/woman.wav" + ref_text_path = "/home/zydi//worker_chat/kafka/sample/woman.txt" + ref_language = "中文" + target_text = """我们开发了"病人实时健康监测系统"和"AI辅助诊断系统",这些系统显著提高了医疗诊断的效率和准确性。obscura形成了全面的医疗智能解决方案""" + + target_language = "多语种混合" + output_path = "/home/zydi//worker_chat/kafka" + output_filename = "output.wav" + + synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text, target_language, output_path, output_filename) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/chat_history/whisper_service.py b/chat_history/whisper_service.py new file mode 100644 index 0000000..8d43a7a --- /dev/null +++ b/chat_history/whisper_service.py @@ -0,0 +1,25 @@ +from flask import Flask, request, jsonify +import whisper +import tempfile +from flask_cors import CORS + +app = Flask(__name__) +CORS(app) +print("Loading Whisper model...") +model = whisper.load_model("small") +print("Whisper model loaded.") + +@app.route('/transcribe', methods=['POST']) +def transcribe(): + if 'audio' not in request.files: + return jsonify({"error": "No audio file provided"}), 400 + + audio_file = request.files['audio'] + with 
tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio: + audio_file.save(temp_audio.name) + result = model.transcribe(temp_audio.name) + + return jsonify({"transcription": result['text']}) + +if __name__ == '__main__': + app.run(port=5000) \ No newline at end of file